Skip to content

Commit 9403f9b

Browse files
Nicoshevfacebook-github-bot
authored andcommitted
Add sve implementation for float matrix transpose (#3421)
Summary: X-link: facebookresearch/FBGEMM#509 Adding sve-based function for transposing float matrixes Differential Revision: D66528598
1 parent cffa05a commit 9403f9b

File tree

5 files changed

+864
-2
lines changed

5 files changed

+864
-2
lines changed

defs.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,10 @@ def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False):
143143
return asm_srcs if not msvc else intrinsics_srcs
144144

145145
def get_fbgemm_inline_sve_srcs(msvc = False, buck = False):
146-
intrinsics_srcs = ["src/FbgemmFP16UKernelsSve128.cc"]
146+
intrinsics_srcs = ["src/FbgemmFP16UKernelsSve128.cc", "src/UtilsSve.cc"]
147147

148148
#FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different.
149-
asm_srcs = ["src/FbgemmFP16UKernelsSve128.cc"]
149+
asm_srcs = ["src/FbgemmFP16UKernelsSve128.cc", "src/UtilsSve.cc"]
150150
if buck:
151151
return select({
152152
"DEFAULT": asm_srcs,

src/TransposeUtils.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ void transpose_simd(
4646
}
4747
return;
4848
}
49+
50+
#ifdef __aarch64__
51+
if constexpr (std::is_same<T, float>::value) {
52+
internal::transpose_sve<T>(M, N, src, ld_src, dst, ld_dst);
53+
} else {
54+
transpose_ref<T>(M, N, src, ld_src, dst, ld_dst);
55+
}
56+
#else
4957
static const auto iset = fbgemmInstructionSet();
5058
// Run time CPU detection
5159
if (isZmm(iset)) {
@@ -55,6 +63,8 @@ void transpose_simd(
5563
} else {
5664
transpose_ref<T>(M, N, src, ld_src, dst, ld_dst);
5765
}
66+
67+
#endif
5868
}
5969

6070
template void transpose_ref<float>(

src/TransposeUtils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,22 @@ void transpose_avx512(
6262
T* dst,
6363
int64_t ld_dst);
6464

65+
#ifdef __aarch64__
66+
/**
67+
* @brief Transpose a matrix using Intel AVX2.
68+
*
69+
* This is called if the code is running on a CPU with Intel AVX2 support.
70+
*/
71+
template <typename T>
72+
void transpose_sve(
73+
int64_t M,
74+
int64_t N,
75+
const T* src,
76+
int64_t ld_src,
77+
T* dst,
78+
int64_t ld_dst);
79+
#endif // __aarch64__
80+
6581
} // namespace internal
6682

6783
} // namespace fbgemm

0 commit comments

Comments
 (0)