@@ -17,6 +17,10 @@ struct KernelAvx2;
// Marker type for the FMA kernel; only compiled for x86 / x86_64 targets.
1717#[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
1818struct KernelFma ;
1919
// Marker type for the NEON kernel. It only exists when compiling for aarch64
// AND the custom `has_aarch64_simd` cfg flag is set.
// NOTE(review): `has_aarch64_simd` is presumably emitted by the build script
// for toolchains that support aarch64 SIMD intrinsics -- confirm.
20+ #[ cfg( target_arch = "aarch64" ) ]
21+ #[ cfg( has_aarch64_simd) ]
22+ struct KernelNeon ;
23+
// Marker type for the fallback kernel. It has no cfg gate, so it is available
// on every target; `detect` selects it last, when no arch-specific kernel was
// returned earlier.
2024struct KernelFallback ;
2125
// Element type used by the kernels in this file (it is what each kernel
// reports as `GemmKernel::Elem`).
// NOTE(review): `c32` is presumably the crate's single-precision complex
// type -- confirm against its definition.
2226type T = c32 ;
@@ -39,6 +43,13 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
3943 return selector. select ( KernelFma ) ;
4044 }
4145 }
46+ #[ cfg( target_arch = "aarch64" ) ]
47+ #[ cfg( has_aarch64_simd) ]
48+ {
49+ if is_aarch64_feature_detected_ ! ( "neon" ) {
50+ return selector. select ( KernelNeon ) ;
51+ }
52+ }
4253 return selector. select ( KernelFallback ) ;
4354}
4455
@@ -110,6 +121,41 @@ impl GemmKernel for KernelFma {
110121 }
111122}
112123
// GemmKernel implementation for the aarch64 NEON kernel.
// Gated exactly like the `KernelNeon` type itself: aarch64 target plus the
// `has_aarch64_simd` cfg flag.
124+ #[ cfg( target_arch = "aarch64" ) ]
125+ #[ cfg( has_aarch64_simd) ]
126+ impl GemmKernel for KernelNeon {
127+ type Elem = T ;
128+
// Kernel tile dimensions as type-level integers: MR = 4, NR = 2.
// These agree with the neon `loop_m` / `loop_n` macros in this file, which
// expand to `loop4!` / `loop2!`.
// NOTE(review): presumably MR counts rows and NR counts columns of the
// micro-tile -- confirm against the GemmKernel trait docs.
129+ type MRTy = U4 ;
130+ type NRTy = U2 ;
131+
// NOTE(review): 16 is presumably an alignment in bytes -- confirm the unit
// against the trait definition.
132+ #[ inline( always) ]
133+ fn align_to ( ) -> usize { 16 }
134+
// Defers to the fallback kernel's answer instead of hard-coding one here.
135+ #[ inline( always) ]
136+ fn always_masked ( ) -> bool { KernelFallback :: always_masked ( ) }
137+
// Blocking parameters are read from the `archparam` constants with the
// `C_` prefix.
138+ #[ inline( always) ]
139+ fn nc ( ) -> usize { archparam:: C_NC }
140+ #[ inline( always) ]
141+ fn kc ( ) -> usize { archparam:: C_KC }
142+ #[ inline( always) ]
143+ fn mc ( ) -> usize { archparam:: C_MC }
144+
// NOTE(review): `pack_methods!` is defined outside this view; presumably it
// expands to the packing methods required by GemmKernel -- confirm.
145+ pack_methods ! { }
146+
// Thin forwarder to `kernel_target_neon`, the function generated by the
// `kernel_fallback_impl_complex!` invocation for the neon kernel.
// All arguments are passed through unchanged, in the same order.
// The function is `unsafe` and works on raw pointers; the validity
// requirements on `a`, `b` and `c` are those of `kernel_target_neon` and are
// not visible in this block.
147+ #[ inline( always) ]
148+ unsafe fn kernel (
149+ k : usize ,
150+ alpha : T ,
151+ a : * const T ,
152+ b : * const T ,
153+ beta : T ,
154+ c : * mut T , rsc : isize , csc : isize ) {
155+ kernel_target_neon ( k, alpha, a, b, beta, c, rsc, csc)
156+ }
157+ }
158+
113159impl GemmKernel for KernelFallback {
114160 type Elem = T ;
115161
@@ -170,6 +216,22 @@ kernel_fallback_impl_complex! {
170216 kernel_target_fma, T , TReal , KernelFma :: MR , KernelFma :: NR , 2
171217}
172218
219+ // Kernel neon
220+
// `loop_m` for the neon kernel: forwards to `loop4!`, in line with MRTy = U4.
// `macro_rules!` definitions are textually scoped, so this definition is the
// one seen by the neon kernel invocation that follows; `loop_m` is defined
// again further down for the fallback kernel.
221+ #[ cfg( target_arch = "aarch64" ) ]
222+ #[ cfg( has_aarch64_simd) ]
223+ macro_rules! loop_m { ( $i: ident, $e: expr) => { loop4!( $i, $e) } ; }
// `loop_n` for the neon kernel: forwards to `loop2!`, in line with NRTy = U2.
224+ #[ cfg( target_arch = "aarch64" ) ]
225+ #[ cfg( has_aarch64_simd) ]
226+ macro_rules! loop_n { ( $j: ident, $e: expr) => { loop2!( $j, $e) } ; }
227+
// Generates `kernel_target_neon` from the shared complex kernel template.
// The first bracket lists the attributes requested for the generated
// function (`inline` and `target_feature(enable="neon")`); the second
// bracket passes the `fma_yes` option.
// NOTE(review): the meaning of the trailing `1` is not visible here (the FMA
// kernel passes `2` in the same position); presumably an unroll factor --
// confirm against the macro definition.
228+ #[ cfg( target_arch = "aarch64" ) ]
229+ #[ cfg( has_aarch64_simd) ]
230+ kernel_fallback_impl_complex ! {
231+ [ inline target_feature( enable="neon" ) ] [ fma_yes]
232+ kernel_target_neon, T , TReal , KernelNeon :: MR , KernelNeon :: NR , 1
233+ }
234+
173235// Kernel fallback
174236
// `loop_m` for the fallback kernel: forwards to `loop4!`. On aarch64 builds
// with `has_aarch64_simd` this shadows the earlier neon definition from this
// point onward (textual macro scoping).
175237macro_rules! loop_m { ( $i: ident, $e: expr) => { loop4!( $i, $e) } ; }
@@ -195,6 +257,34 @@ mod tests {
195257 test_complex_packed_kernel :: < KernelFallback , _ , TReal > ( "kernel" ) ;
196258 }
197259
// Tests for the aarch64-specific kernels. The module is only compiled for
// aarch64 targets with the `has_aarch64_simd` cfg flag set.
260+ #[ cfg( target_arch = "aarch64" ) ]
261+ #[ cfg( has_aarch64_simd) ]
262+ mod test_kernel_aarch64 {
263+ use super :: test_complex_packed_kernel;
264+ use super :: super :: * ;
// `println!` is only imported (and only used below) when the `std` feature
// is enabled.
265+ #[ cfg( feature = "std" ) ]
266+ use std:: println;
// Generates one `#[test]` function per `feature, name, kernel type` triple.
// Each test checks at run time whether the host reports the feature:
//   - if so, it runs `test_complex_packed_kernel` for the kernel type,
//     passing the test's own name as the label;
//   - otherwise it does not run the kernel and, with `std` enabled, prints
//     a "Skipping" message.
267+ macro_rules! test_arch_kernels {
268+ ( $( $feature_name: tt, $name: ident, $kernel_ty: ty) ,* ) => {
269+ $(
270+ #[ test]
271+ fn $name( ) {
272+ if is_aarch64_feature_detected_!( $feature_name) {
273+ test_complex_packed_kernel:: <$kernel_ty, _, TReal >( stringify!( $name) ) ;
274+ } else {
275+ #[ cfg( feature = "std" ) ]
276+ println!( "Skipping, host does not have feature: {:?}" , $feature_name) ;
277+ }
278+ }
279+ ) *
280+ }
281+ }
282+
// Instantiates a single test, `neon`, which exercises `KernelNeon` when the
// "neon" feature is detected.
283+ test_arch_kernels ! {
284+ "neon" , neon, KernelNeon
285+ }
286+ }
287+
198288 #[ cfg( any( target_arch="x86" , target_arch="x86_64" ) ) ]
199289 mod test_arch_kernels {
200290 use super :: test_complex_packed_kernel;
0 commit comments