Skip to content

Commit e84562d

Browse files
committed
cgemm: enable fma for neon
1 parent e6d04e1 commit e84562d

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,13 @@ jobs:
123123
include:
124124
- rust: stable
125125
target: aarch64-unknown-linux-gnu
126+
features: constconf cgemm threading
126127
- rust: 1.61.0
127128
target: aarch64-unknown-linux-gnu
129+
features: cgemm
128130
- rust: 1.41.1 # MSRV
129131
target: aarch64-unknown-linux-gnu
132+
features: cgemm
130133

131134
steps:
132135
- uses: actions/checkout@v2
@@ -146,7 +149,7 @@ jobs:
146149
if: steps.cache.outputs.cache-hit != 'true'
147150
run: cargo install cross
148151
- name: Tests
149-
run: cross test --target "${{ matrix.target }}"
152+
run: cross test --target "${{ matrix.target }}" --features "${{ matrix.features }}"
150153
env:
151154
MMTEST_FAST_TEST: 1
152155
RUSTFLAGS: -Copt-level=2

src/cgemm_kernel.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ struct KernelAvx2;
1717
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
1818
struct KernelFma;
1919

20+
#[cfg(target_arch = "aarch64")]
21+
#[cfg(has_aarch64_simd)]
22+
struct KernelNeon;
23+
2024
struct KernelFallback;
2125

2226
type T = c32;
@@ -39,6 +43,13 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
3943
return selector.select(KernelFma);
4044
}
4145
}
46+
#[cfg(target_arch = "aarch64")]
47+
#[cfg(has_aarch64_simd)]
48+
{
49+
if is_aarch64_feature_detected_!("neon") {
50+
return selector.select(KernelNeon);
51+
}
52+
}
4253
return selector.select(KernelFallback);
4354
}
4455

@@ -110,6 +121,41 @@ impl GemmKernel for KernelFma {
110121
}
111122
}
112123

124+
#[cfg(target_arch = "aarch64")]
125+
#[cfg(has_aarch64_simd)]
126+
impl GemmKernel for KernelNeon {
127+
type Elem = T;
128+
129+
type MRTy = U4;
130+
type NRTy = U2;
131+
132+
#[inline(always)]
133+
fn align_to() -> usize { 16 }
134+
135+
#[inline(always)]
136+
fn always_masked() -> bool { KernelFallback::always_masked() }
137+
138+
#[inline(always)]
139+
fn nc() -> usize { archparam::C_NC }
140+
#[inline(always)]
141+
fn kc() -> usize { archparam::C_KC }
142+
#[inline(always)]
143+
fn mc() -> usize { archparam::C_MC }
144+
145+
pack_methods!{}
146+
147+
#[inline(always)]
148+
unsafe fn kernel(
149+
k: usize,
150+
alpha: T,
151+
a: *const T,
152+
b: *const T,
153+
beta: T,
154+
c: *mut T, rsc: isize, csc: isize) {
155+
kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
156+
}
157+
}
158+
113159
impl GemmKernel for KernelFallback {
114160
type Elem = T;
115161

@@ -170,6 +216,22 @@ kernel_fallback_impl_complex! {
170216
kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
171217
}
172218

219+
// Kernel neon
220+
221+
#[cfg(target_arch = "aarch64")]
222+
#[cfg(has_aarch64_simd)]
223+
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
224+
#[cfg(target_arch = "aarch64")]
225+
#[cfg(has_aarch64_simd)]
226+
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
227+
228+
#[cfg(target_arch = "aarch64")]
229+
#[cfg(has_aarch64_simd)]
230+
kernel_fallback_impl_complex! {
231+
[inline target_feature(enable="neon")] [fma_yes]
232+
kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
233+
}
234+
173235
// Kernel fallback
174236

175237
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
@@ -195,6 +257,34 @@ mod tests {
195257
test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
196258
}
197259

260+
#[cfg(target_arch = "aarch64")]
261+
#[cfg(has_aarch64_simd)]
262+
mod test_kernel_aarch64 {
263+
use super::test_complex_packed_kernel;
264+
use super::super::*;
265+
#[cfg(feature = "std")]
266+
use std::println;
267+
macro_rules! test_arch_kernels {
268+
($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
269+
$(
270+
#[test]
271+
fn $name() {
272+
if is_aarch64_feature_detected_!($feature_name) {
273+
test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name));
274+
} else {
275+
#[cfg(feature = "std")]
276+
println!("Skipping, host does not have feature: {:?}", $feature_name);
277+
}
278+
}
279+
)*
280+
}
281+
}
282+
283+
test_arch_kernels! {
284+
"neon", neon, KernelNeon
285+
}
286+
}
287+
198288
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
199289
mod test_arch_kernels {
200290
use super::test_complex_packed_kernel;

src/zgemm_kernel.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ struct KernelAvx2;
1717
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
1818
struct KernelFma;
1919

20+
#[cfg(target_arch = "aarch64")]
21+
#[cfg(has_aarch64_simd)]
22+
struct KernelNeon;
23+
2024
struct KernelFallback;
2125

2226
type T = c64;
@@ -39,6 +43,13 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
3943
return selector.select(KernelFma);
4044
}
4145
}
46+
#[cfg(target_arch = "aarch64")]
47+
#[cfg(has_aarch64_simd)]
48+
{
49+
if is_aarch64_feature_detected_!("neon") {
50+
return selector.select(KernelNeon);
51+
}
52+
}
4253
return selector.select(KernelFallback);
4354
}
4455

@@ -113,6 +124,41 @@ impl GemmKernel for KernelFma {
113124
}
114125
}
115126

127+
#[cfg(target_arch = "aarch64")]
128+
#[cfg(has_aarch64_simd)]
129+
impl GemmKernel for KernelNeon {
130+
type Elem = T;
131+
132+
type MRTy = U4;
133+
type NRTy = U2;
134+
135+
#[inline(always)]
136+
fn align_to() -> usize { 16 }
137+
138+
#[inline(always)]
139+
fn always_masked() -> bool { KernelFallback::always_masked() }
140+
141+
#[inline(always)]
142+
fn nc() -> usize { archparam::Z_NC }
143+
#[inline(always)]
144+
fn kc() -> usize { archparam::Z_KC }
145+
#[inline(always)]
146+
fn mc() -> usize { archparam::Z_MC }
147+
148+
pack_methods!{}
149+
150+
#[inline(always)]
151+
unsafe fn kernel(
152+
k: usize,
153+
alpha: T,
154+
a: *const T,
155+
b: *const T,
156+
beta: T,
157+
c: *mut T, rsc: isize, csc: isize) {
158+
kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
159+
}
160+
}
161+
116162
impl GemmKernel for KernelFallback {
117163
type Elem = T;
118164

@@ -160,6 +206,17 @@ kernel_fallback_impl_complex! {
160206
kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
161207
}
162208

209+
// Kernel neon
210+
211+
#[cfg(target_arch = "aarch64")]
212+
#[cfg(has_aarch64_simd)]
213+
kernel_fallback_impl_complex! {
214+
[inline target_feature(enable="neon")] [fma_yes]
215+
kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
216+
}
217+
218+
// kernel fallback
219+
163220
kernel_fallback_impl_complex! {
164221
[inline] [fma_no]
165222
kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1
@@ -180,6 +237,34 @@ mod tests {
180237
test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
181238
}
182239

240+
#[cfg(target_arch = "aarch64")]
241+
#[cfg(has_aarch64_simd)]
242+
mod test_kernel_aarch64 {
243+
use super::test_complex_packed_kernel;
244+
use super::super::*;
245+
#[cfg(feature = "std")]
246+
use std::println;
247+
macro_rules! test_arch_kernels {
248+
($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
249+
$(
250+
#[test]
251+
fn $name() {
252+
if is_aarch64_feature_detected_!($feature_name) {
253+
test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name));
254+
} else {
255+
#[cfg(feature = "std")]
256+
println!("Skipping, host does not have feature: {:?}", $feature_name);
257+
}
258+
}
259+
)*
260+
}
261+
}
262+
263+
test_arch_kernels! {
264+
"neon", neon, KernelNeon
265+
}
266+
}
267+
183268
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
184269
mod test_arch_kernels {
185270
use super::test_complex_packed_kernel;

0 commit comments

Comments
 (0)