@@ -12,6 +12,8 @@ use crate::kernel::{U2, U4, c32, Element, c32_mul as mul};
1212use crate :: archparam;
1313use crate :: cgemm_common:: pack_complex;
1414
15+ #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
16+ struct KernelAvx2 ;
1517#[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
1618struct KernelFma ;
1719
@@ -30,22 +32,56 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
3032 // dispatch to specific compiled versions
3133 #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
3234 {
35+ if is_x86_feature_detected_ ! ( "avx2" ) {
36+ return selector. select ( KernelAvx2 ) ;
37+ }
3338 if is_x86_feature_detected_ ! ( "fma" ) {
3439 return selector. select ( KernelFma ) ;
3540 }
3641 }
3742 return selector. select ( KernelFallback ) ;
3843}
3944
40- macro_rules! loop_m { ( $i: ident, $e: expr) => { loop4!( $i, $e) } ; }
41- macro_rules! loop_n { ( $j: ident, $e: expr) => { loop2!( $j, $e) } ; }
45+ #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
46+ impl GemmKernel for KernelAvx2 {
47+ type Elem = T ;
48+
49+ type MRTy = U4 ;
50+ type NRTy = U4 ;
51+
52+ #[ inline( always) ]
53+ fn align_to ( ) -> usize { 32 }
54+
55+ #[ inline( always) ]
56+ fn always_masked ( ) -> bool { KernelFallback :: always_masked ( ) }
57+
58+ #[ inline( always) ]
59+ fn nc ( ) -> usize { archparam:: C_NC }
60+ #[ inline( always) ]
61+ fn kc ( ) -> usize { archparam:: C_KC }
62+ #[ inline( always) ]
63+ fn mc ( ) -> usize { archparam:: C_MC }
64+
65+ pack_methods ! { }
66+
67+ #[ inline( always) ]
68+ unsafe fn kernel (
69+ k : usize ,
70+ alpha : T ,
71+ a : * const T ,
72+ b : * const T ,
73+ beta : T ,
74+ c : * mut T , rsc : isize , csc : isize ) {
75+ kernel_target_avx2 ( k, alpha, a, b, beta, c, rsc, csc)
76+ }
77+ }
4278
4379#[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
4480impl GemmKernel for KernelFma {
4581 type Elem = T ;
4682
47- type MRTy = < KernelFallback as GemmKernel > :: MRTy ;
48- type NRTy = < KernelFallback as GemmKernel > :: NRTy ;
83+ type MRTy = U4 ;
84+ type NRTy = U4 ;
4985
5086 #[ inline( always) ]
5187 fn align_to ( ) -> usize { 16 }
@@ -107,13 +143,36 @@ impl GemmKernel for KernelFallback {
107143 }
108144}
109145
146+ // Kernel AVX2
147+ #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
148+ macro_rules! loop_m { ( $i: ident, $e: expr) => { loop4!( $i, $e) } ; }
149+ #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
150+ macro_rules! loop_n { ( $j: ident, $e: expr) => { loop4!( $j, $e) } ; }
110151
111152#[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
112153kernel_fallback_impl_complex ! {
113- // instantiate fma separately to use an unroll count that works better here
114- [ inline target_feature( enable="fma " ) ] kernel_target_fma , T , TReal , KernelFallback :: MR , KernelFallback :: NR , 2
154+ // instantiate fma separately
155+ [ inline target_feature( enable="avx2 " ) ] kernel_target_avx2 , T , TReal , KernelAvx2 :: MR , KernelAvx2 :: NR , 1
115156}
116157
158+
159+ // Kernel Fma
160+ #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
161+ macro_rules! loop_m { ( $i: ident, $e: expr) => { loop4!( $i, $e) } ; }
162+ #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
163+ macro_rules! loop_n { ( $j: ident, $e: expr) => { loop4!( $j, $e) } ; }
164+
165+ #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
166+ kernel_fallback_impl_complex ! {
167+ // instantiate fma separately
168+ [ inline target_feature( enable="fma" ) ] kernel_target_fma, T , TReal , KernelFma :: MR , KernelFma :: NR , 2
169+ }
170+
171+ // Kernel fallback
172+
173+ macro_rules! loop_m { ( $i: ident, $e: expr) => { loop4!( $i, $e) } ; }
174+ macro_rules! loop_n { ( $j: ident, $e: expr) => { loop2!( $j, $e) } ; }
175+
117176kernel_fallback_impl_complex ! { [ inline( always) ] kernel_fallback_impl, T , TReal , KernelFallback :: MR , KernelFallback :: NR , 1 }
118177
119178#[ inline( always) ]
@@ -154,7 +213,8 @@ mod tests {
154213 }
155214
156215 test_arch_kernels_x86 ! {
157- "fma" , fma, KernelFma
216+ "fma" , fma, KernelFma ,
217+ "avx2" , avx2, KernelAvx2
158218 }
159219 }
160220}
0 commit comments