66// option. This file may not be copied, modified, or distributed
77// except according to those terms.
88
9+ use rayon;
910use libnum:: Zero ;
1011use itertools:: free:: enumerate;
1112
@@ -417,6 +418,8 @@ fn mat_mul_impl<A>(alpha: A,
417418 mat_mul_general ( alpha, lhs, rhs, beta, c)
418419}
419420
421+ const SPLIT : usize = 64 ;
422+
420423/// C ← α A B + β C
421424fn mat_mul_general < A > ( alpha : A ,
422425 lhs : & ArrayView < A , ( Ix , Ix ) > ,
@@ -425,7 +428,27 @@ fn mat_mul_general<A>(alpha: A,
425428 c : & mut ArrayViewMut < A , ( Ix , Ix ) > )
426429 where A : LinalgScalar ,
427430{
428- let ( ( m, k) , ( _, n) ) = ( lhs. dim , rhs. dim ) ;
431+ let ( ( m, k) , ( k2, n) ) = ( lhs. dim , rhs. dim ) ;
432+
433+ debug_assert_eq ! ( k, k2) ;
434+ if m > SPLIT {
435+ // [ A0 ] B = [ C0 ]
436+ // [ A1 ] [ C1 ]
437+ let mid = m / 2 ;
438+ let ( a0, a1) = lhs. split_at ( Axis ( 0 ) , mid) ;
439+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 0 ) , mid) ;
440+ rayon:: join ( move || mat_mul_general ( alpha, & a0, rhs, beta, & mut c0) ,
441+ move || mat_mul_general ( alpha, & a1, rhs, beta, & mut c1) ) ;
442+ return ;
443+ } else if n > SPLIT {
444+ // A [ B0 B1 ] = [ C0 C1 ]
445+ let mid = n / 2 ;
446+ let ( b0, b1) = rhs. split_at ( Axis ( 1 ) , mid) ;
447+ let ( mut c0, mut c1) = c. view_mut ( ) . split_at ( Axis ( 1 ) , mid) ;
448+ rayon:: join ( move || mat_mul_general ( alpha, lhs, & b0, beta, & mut c0) ,
449+ move || mat_mul_general ( alpha, lhs, & b1, beta, & mut c1) ) ;
450+ return ;
451+ }
429452
430453 // common parameters for gemm
431454 let ap = lhs. as_ptr ( ) ;
0 commit comments