@@ -3,6 +3,7 @@ extern crate defmac;
3
3
extern crate ndarray;
4
4
extern crate num_traits;
5
5
extern crate blas_src;
6
+ extern crate num_complex;
6
7
7
8
use ndarray:: prelude:: * ;
8
9
@@ -12,6 +13,8 @@ use ndarray::{Data, Ix, LinalgScalar};
12
13
13
14
use approx:: assert_relative_eq;
14
15
use defmac:: defmac;
16
+ use num_complex:: Complex32 ;
17
+ use num_complex:: Complex64 ;
15
18
16
19
#[ test]
17
20
fn mat_vec_product_1d ( ) {
@@ -52,6 +55,20 @@ fn range_mat64(m: Ix, n: Ix) -> Array2<f64> {
52
55
. unwrap ( )
53
56
}
54
57
58
+ fn range_mat_complex ( m : Ix , n : Ix ) -> Array2 < Complex32 > {
59
+ Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
60
+ . into_shape ( ( m, n) )
61
+ . unwrap ( )
62
+ . map_mut ( |& mut f| Complex32 :: new ( f, 0. ) )
63
+ }
64
+
65
+ fn range_mat_complex64 ( m : Ix , n : Ix ) -> Array2 < Complex64 > {
66
+ Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
67
+ . into_shape ( ( m, n) )
68
+ . unwrap ( )
69
+ . map_mut ( |& mut f| Complex64 :: new ( f, 0. ) )
70
+ }
71
+
55
72
fn range1_mat64 ( m : Ix ) -> Array1 < f64 > {
56
73
Array :: linspace ( 0. , m as f64 - 1. , m)
57
74
}
@@ -250,6 +267,30 @@ fn gemm_64_1_f() {
250
267
assert_relative_eq ! ( y, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
251
268
}
252
269
270
+ #[ test]
271
+ fn gemm_c64_1_f ( ) {
272
+ let a = range_mat_complex64 ( 64 , 64 ) . reversed_axes ( ) ;
273
+ let ( m, n) = a. dim ( ) ;
274
+ // m x n times n x 1 == m x 1
275
+ let x = range_mat_complex64 ( n, 1 ) ;
276
+ let mut y = range_mat_complex64 ( m, 1 ) ;
277
+ let answer = reference_mat_mul ( & a, & x) + & y;
278
+ general_mat_mul ( Complex64 :: new ( 1.0 , 0. ) , & a, & x, Complex64 :: new ( 1.0 , 0. ) , & mut y) ;
279
+ assert_relative_eq ! ( y. mapv( |i| i. norm_sqr( ) ) , answer. mapv( |i| i. norm_sqr( ) ) , epsilon = 1e-12 , max_relative = 1e-7 ) ;
280
+ }
281
+
282
+ #[ test]
283
+ fn gemm_c32_1_f ( ) {
284
+ let a = range_mat_complex ( 64 , 64 ) . reversed_axes ( ) ;
285
+ let ( m, n) = a. dim ( ) ;
286
+ // m x n times n x 1 == m x 1
287
+ let x = range_mat_complex ( n, 1 ) ;
288
+ let mut y = range_mat_complex ( m, 1 ) ;
289
+ let answer = reference_mat_mul ( & a, & x) + & y;
290
+ general_mat_mul ( Complex32 :: new ( 1.0 , 0. ) , & a, & x, Complex32 :: new ( 1.0 , 0. ) , & mut y) ;
291
+ assert_relative_eq ! ( y. mapv( |i| i. norm_sqr( ) ) , answer. mapv( |i| i. norm_sqr( ) ) , epsilon = 1e-12 , max_relative = 1e-7 ) ;
292
+ }
293
+
253
294
#[ test]
254
295
fn gen_mat_vec_mul ( ) {
255
296
let alpha = -2.3 ;
0 commit comments