Don't require -march compiler flags to use bfdot (#5444)
Summary:
Pull Request resolved: #5444

TIL about the `target` clang/GCC function attribute, which allows building a particular function as if under an `-march` flag instead of requiring the flag for the whole file.
ghstack-source-id: 243858419

Reviewed By: malfet

Differential Revision: D62905047

fbshipit-source-id: a89c8169fea315aa653bbca819a672357c3dff77
swolchok authored and facebook-github-bot committed Sep 21, 2024
1 parent c50f9fe commit 0eee42a
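To make the mechanism concrete, here is a minimal, self-contained sketch of the idea (illustrative functions, not taken from this commit; it assumes an aarch64 toolchain recent enough to declare the BF16 intrinsics, which the change itself relies on):

```cpp
#include <arm_neon.h>

// Only this function is compiled as armv8.2-a+bf16; the rest of the
// translation unit keeps its default -march, so no file-wide flag is needed.
// Callers must gate it behind a runtime CPU check such as
// cpuinfo_has_arm_bf16().
__attribute__((target("arch=armv8.2-a+bf16")))
float32x4_t dot_step_bfdot(float32x4_t acc, const __bf16* a, const __bf16* b) {
  return vbfdotq_f32(acc, vld1q_bf16(a), vld1q_bf16(b));
}

// Fallback path, built with whatever architecture the file is compiled for.
float32x4_t dot_step_fma(float32x4_t acc, const float* a, const float* b) {
  return vfmaq_f32(acc, vld1q_f32(a), vld1q_f32(b));
}
```

The diff below applies the same idea through an `ET_TARGET_ARM_BF16_ATTRIBUTE` macro and keeps the existing `cpuinfo_has_arm_bf16()` runtime dispatch, which is what allows the per-platform `-march=armv8+bf16` build flags to be dropped.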
Showing 2 changed files with 76 additions and 54 deletions.
116 changes: 76 additions & 40 deletions kernels/optimized/blas/BlasKernel.cpp
@@ -74,43 +74,60 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
   return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
 }
 
-#ifdef __ARM_FEATURE_BF16
-static ET_INLINE float32x4_t
+#define ET_TARGET_ARM_BF16_ATTRIBUTE \
+  __attribute__((target("arch=armv8.2-a+bf16")))
+ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
 f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
   return vbfdotq_f32(a, b, c);
 }
-#endif // __ARM_FEATURE_BF16
 
-template <bool useBfloat16Dot>
-static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
+ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
+dot_with_fp32_arith_main_inner_loop_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t sum[kF32RegistersPerIteration],
+    int registerPairIndex) {
+  const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  sum[registerPairIndex] =
+      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
+}
+
+static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t sum[kF32RegistersPerIteration],
+    int registerPairIndex) {
+  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+  sum[2 * registerPairIndex] = f32_fma_bf16(
+      sum[2 * registerPairIndex],
+      vget_low_u16(temp_vec1),
+      vget_low_u16(temp_vec2));
+  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+      sum[2 * registerPairIndex + 1],
+      vget_high_u16(temp_vec1),
+      vget_high_u16(temp_vec2));
+}
+
+template <bool useBfdot>
+ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
+dot_with_fp32_arith_main_inner_loop(
     const BFloat16* vec1,
     const BFloat16* vec2,
     float32x4_t sum[kF32RegistersPerIteration],
     int registerPairIndex) {
-#ifdef __ARM_FEATURE_BF16
-  if (useBfloat16Dot) {
-    const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-    const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
-        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-    sum[registerPairIndex] =
-        f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
-  } else
-#endif // __ARM_FEATURE_BF16
-  {
-    const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-        &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
-    const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
-        &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
-
-    sum[2 * registerPairIndex] = f32_fma_bf16(
-        sum[2 * registerPairIndex],
-        vget_low_u16(temp_vec1),
-        vget_low_u16(temp_vec2));
-    sum[2 * registerPairIndex + 1] = f32_fma_bf16(
-        sum[2 * registerPairIndex + 1],
-        vget_high_u16(temp_vec1),
-        vget_high_u16(temp_vec2));
+  if constexpr (useBfdot) {
+    dot_with_fp32_arith_main_inner_loop_bfdot(
+        vec1, vec2, sum, registerPairIndex);
+  } else {
+    dot_with_fp32_arith_main_inner_loop_no_bfdot(
+        vec1, vec2, sum, registerPairIndex);
   }
 }
 
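The attribute has to appear on the callers as well as on `f32_dot_bf16`: the NEON/BF16 intrinsics (and the `ET_INLINE` helpers here) are always_inline, and clang and GCC refuse to inline a function that requires `+bf16` into a caller compiled without that feature. A reduced sketch of that rule, with illustrative names that are not part of the diff:

```cpp
#include <arm_neon.h>

#define TARGET_BF16 __attribute__((target("arch=armv8.2-a+bf16")))

TARGET_BF16 __attribute__((always_inline)) static inline float32x4_t
bfdot_helper(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
  return vbfdotq_f32(acc, a, b);  // needs the +bf16 target feature
}

// Fine: the caller advertises the same target, so the helper and the
// always_inline intrinsic can be inlined into it.
TARGET_BF16 static float32x4_t
caller_with_target(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
  return bfdot_helper(acc, a, b);
}

// A caller without TARGET_BF16 would fail to compile: an always_inline
// function that requires the 'bf16' feature cannot be inlined into a function
// built without it. That is why every function on the bfdot path carries
// ET_TARGET_ARM_BF16_ATTRIBUTE, while the _no_bfdot helper does not need it.
```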
@@ -126,18 +143,40 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
   *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
 }
 
-template <typename T, bool useBfloat16Dot>
-float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
+namespace {
+template <int n>
+struct ForcedUnrollTargetBFloat16 {
+  template <typename Func>
+  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
+    ForcedUnrollTargetBFloat16<n - 1>{}(f);
+    f(n - 1);
+  }
+};
+
+template <>
+struct ForcedUnrollTargetBFloat16<1> {
+  template <typename Func>
+  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
+    f(0);
+  }
+};
+
+} // namespace
+
+template <typename T, bool useBFloat16Dot>
+ET_TARGET_ARM_BF16_ATTRIBUTE float
+dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
   for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
-        [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE {
-          dot_with_fp32_arith_main_inner_loop<useBfloat16Dot>(
-              vec1_, vec2_, sum, k);
-        });
+    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k)
+            ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
+              dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
+                  vec1_, vec2_, sum, k);
+            });
   }
   auto reducedSum = reduce(sum);
 
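The dedicated `ForcedUnrollTargetBFloat16` exists because the target attribute is not inherited: the unroller's `operator()` and the lambda it invokes sit in the middle of an always_inline chain between two `ET_TARGET_ARM_BF16_ATTRIBUTE` functions, so they need the attribute themselves. For comparison, the generic helper it duplicates presumably looks like the following (an assumption about `utils::ForcedUnroll`, shown only to highlight that the attribute is the sole difference):

```cpp
// Compile-time forced unrolling: ForcedUnroll<N>{}(f) expands to
// f(0); f(1); ...; f(N - 1); without a runtime loop.
template <int n>
struct ForcedUnroll {
  template <typename Func>
  inline void operator()(const Func& f) const {
    ForcedUnroll<n - 1>{}(f);
    f(n - 1);
  }
};

template <>
struct ForcedUnroll<1> {
  template <typename Func>
  inline void operator()(const Func& f) const {
    f(0);
  }
};
```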
@@ -163,12 +202,9 @@ float bf16_dot_with_fp32_arith(
     const BFloat16* vec1,
     const BFloat16* vec2,
     int64_t len) {
-#ifdef __ARM_FEATURE_BF16
   if (cpuinfo_has_arm_bf16()) {
     return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
-  } else
-#endif // __ARM_FEATURE_BF16
-  {
+  } else {
     return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
   }
 }
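With the compile-time `__ARM_FEATURE_BF16` gate gone, the only remaining guard is the runtime CPU check. One detail the hunk does not show: the cpuinfo library generally needs `cpuinfo_initialize()` to have been called (presumably done elsewhere in the runtime) before `cpuinfo_has_arm_bf16()` gives a meaningful answer. A hypothetical call site, reusing the declarations above, might look like:

```cpp
#include <cpuinfo.h>

#include <cstdint>

// Illustrative only: resolve the CPU feature once, then dispatch to the bfdot
// or plain-NEON instantiation of dot_with_fp32_arith.
float bf16_dot(const BFloat16* a, const BFloat16* b, int64_t n) {
  static const bool use_bfdot = cpuinfo_initialize() && cpuinfo_has_arm_bf16();
  return use_bfdot ? dot_with_fp32_arith<BFloat16, true>(a, b, n)
                   : dot_with_fp32_arith<BFloat16, false>(a, b, n);
}
```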
14 changes: 0 additions & 14 deletions kernels/optimized/lib_defs.bzl
@@ -132,14 +132,6 @@ def define_libs():
             ] if not runtime.is_oss else [],
             "DEFAULT": [],
         }),
-        fbandroid_platform_compiler_flags = [
-            (
-                "^android-arm64.*$",
-                [
-                    "-march=armv8+bf16",
-                ],
-            ),
-        ],
         fbandroid_platform_preprocessor_flags = [
             (
                 "^android-arm64.*$",
@@ -156,12 +148,6 @@
                 ],
             ),
         ],
-        fbobjc_platform_compiler_flags = [
-            (
-                ".*arm64.*",
-                ["-march=armv8+bf16"],
-            ),
-        ],
         fbobjc_exported_preprocessor_flags = [
            "-DET_BUILD_WITH_BLAS",
            "-DET_BUILD_FOR_APPLE",
