Skip to content

Commit 6f86fd9

Browse files
committed
cgemm: Setup Avx2 and Fma autovectorized kernels
Custom sizes for Fma and Avx2 are a win for performance, and Avx2 does better than Fma here, so both can be worthwhile.
1 parent 9896879 commit 6f86fd9

File tree

2 files changed

+117
-10
lines changed

2 files changed

+117
-10
lines changed

src/cgemm_kernel.rs

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use crate::kernel::{U2, U4, c32, Element, c32_mul as mul};
1212
use crate::archparam;
1313
use crate::cgemm_common::pack_complex;
1414

15+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
16+
struct KernelAvx2;
1517
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
1618
struct KernelFma;
1719

@@ -30,22 +32,56 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
3032
// dispatch to specific compiled versions
3133
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
3234
{
35+
if is_x86_feature_detected_!("avx2") {
36+
return selector.select(KernelAvx2);
37+
}
3338
if is_x86_feature_detected_!("fma") {
3439
return selector.select(KernelFma);
3540
}
3641
}
3742
return selector.select(KernelFallback);
3843
}
3944

40-
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
41-
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
45+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
46+
impl GemmKernel for KernelAvx2 {
47+
type Elem = T;
48+
49+
type MRTy = U4;
50+
type NRTy = U4;
51+
52+
#[inline(always)]
53+
fn align_to() -> usize { 32 }
54+
55+
#[inline(always)]
56+
fn always_masked() -> bool { KernelFallback::always_masked() }
57+
58+
#[inline(always)]
59+
fn nc() -> usize { archparam::C_NC }
60+
#[inline(always)]
61+
fn kc() -> usize { archparam::C_KC }
62+
#[inline(always)]
63+
fn mc() -> usize { archparam::C_MC }
64+
65+
pack_methods!{}
66+
67+
#[inline(always)]
68+
unsafe fn kernel(
69+
k: usize,
70+
alpha: T,
71+
a: *const T,
72+
b: *const T,
73+
beta: T,
74+
c: *mut T, rsc: isize, csc: isize) {
75+
kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
76+
}
77+
}
4278

4379
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
4480
impl GemmKernel for KernelFma {
4581
type Elem = T;
4682

47-
type MRTy = <KernelFallback as GemmKernel>::MRTy;
48-
type NRTy = <KernelFallback as GemmKernel>::NRTy;
83+
type MRTy = U4;
84+
type NRTy = U4;
4985

5086
#[inline(always)]
5187
fn align_to() -> usize { 16 }
@@ -107,13 +143,36 @@ impl GemmKernel for KernelFallback {
107143
}
108144
}
109145

146+
// Kernel AVX2
147+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
148+
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
149+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
150+
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
110151

111152
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
112153
kernel_fallback_impl_complex! {
113-
// instantiate fma separately to use an unroll count that works better here
114-
[inline target_feature(enable="fma")] kernel_target_fma, T, TReal, KernelFallback::MR, KernelFallback::NR, 2
154+
// instantiate avx2 separately
155+
[inline target_feature(enable="avx2")] kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 1
115156
}
116157

158+
159+
// Kernel Fma
160+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
161+
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
162+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
163+
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
164+
165+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
166+
kernel_fallback_impl_complex! {
167+
// instantiate fma separately
168+
[inline target_feature(enable="fma")] kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
169+
}
170+
171+
// Kernel fallback
172+
173+
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
174+
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
175+
117176
kernel_fallback_impl_complex! { [inline(always)] kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1 }
118177

119178
#[inline(always)]
@@ -154,7 +213,8 @@ mod tests {
154213
}
155214

156215
test_arch_kernels_x86! {
157-
"fma", fma, KernelFma
216+
"fma", fma, KernelFma,
217+
"avx2", avx2, KernelAvx2
158218
}
159219
}
160220
}

src/zgemm_kernel.rs

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use crate::kernel::{U2, U4, c64, Element, c64_mul as mul};
1212
use crate::archparam;
1313
use crate::cgemm_common::pack_complex;
1414

15+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
16+
struct KernelAvx2;
1517
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
1618
struct KernelFma;
1719

@@ -31,6 +33,9 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
3133
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
3234
{
3335
if is_x86_feature_detected_!("fma") {
36+
if is_x86_feature_detected_!("avx2") {
37+
return selector.select(KernelAvx2);
38+
}
3439
return selector.select(KernelFma);
3540
}
3641
}
@@ -40,6 +45,40 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
4045
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
4146
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
4247

48+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
49+
impl GemmKernel for KernelAvx2 {
50+
type Elem = T;
51+
52+
type MRTy = U4;
53+
type NRTy = U2;
54+
55+
#[inline(always)]
56+
fn align_to() -> usize { 32 }
57+
58+
#[inline(always)]
59+
fn always_masked() -> bool { KernelFallback::always_masked() }
60+
61+
#[inline(always)]
62+
fn nc() -> usize { archparam::Z_NC }
63+
#[inline(always)]
64+
fn kc() -> usize { archparam::Z_KC }
65+
#[inline(always)]
66+
fn mc() -> usize { archparam::Z_MC }
67+
68+
pack_methods!{}
69+
70+
#[inline(always)]
71+
unsafe fn kernel(
72+
k: usize,
73+
alpha: T,
74+
a: *const T,
75+
b: *const T,
76+
beta: T,
77+
c: *mut T, rsc: isize, csc: isize) {
78+
kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
79+
}
80+
}
81+
4382
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
4483
impl GemmKernel for KernelFma {
4584
type Elem = T;
@@ -109,10 +148,17 @@ impl GemmKernel for KernelFallback {
109148

110149
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
111150
kernel_fallback_impl_complex! {
112-
// instantiate fma separately so that it's inlined here
113-
[inline target_feature(enable="fma")] kernel_target_fma, T, TReal, KernelFallback::MR, KernelFallback::NR, 2
151+
// instantiate avx2 (with fma enabled) separately
152+
[inline target_feature(enable="fma") target_feature(enable="avx2")]
153+
kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4
114154
}
115155

156+
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
157+
kernel_fallback_impl_complex! {
158+
// instantiate fma separately
159+
[inline target_feature(enable="fma")]
160+
kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
161+
}
116162

117163
kernel_fallback_impl_complex! { [inline] kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1 }
118164

@@ -154,7 +200,8 @@ mod tests {
154200
}
155201

156202
test_arch_kernels_x86! {
157-
"fma", fma, KernelFma
203+
"fma", fma, KernelFma,
204+
"avx2", avx2, KernelAvx2
158205
}
159206
}
160207
}

0 commit comments

Comments
 (0)