From 0eee42a6b859967346dc41f832cd29e68169ddd5 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 20 Sep 2024 19:32:26 -0700 Subject: [PATCH] Don't require -march compiler flags to use bfdot (#5444) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5444 TIL about the `target` clang/GCC function attribute, which allows building a particular function under an `-march` flag instead of a whole file. ghstack-source-id: 243858419 Reviewed By: malfet Differential Revision: D62905047 fbshipit-source-id: a89c8169fea315aa653bbca819a672357c3dff77 --- kernels/optimized/blas/BlasKernel.cpp | 116 +++++++++++++++++--------- kernels/optimized/lib_defs.bzl | 14 ---- 2 files changed, 76 insertions(+), 54 deletions(-) diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp index cfee709ae6..a3e2172504 100644 --- a/kernels/optimized/blas/BlasKernel.cpp +++ b/kernels/optimized/blas/BlasKernel.cpp @@ -74,43 +74,60 @@ f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) { return f32_fma(a, to_bfloat16(b), to_bfloat16(c)); } -#ifdef __ARM_FEATURE_BF16 -static ET_INLINE float32x4_t +#define ET_TARGET_ARM_BF16_ATTRIBUTE \ + __attribute__((target("arch=armv8.2-a+bf16"))) +ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) { return vbfdotq_f32(a, b, c); } -#endif // __ARM_FEATURE_BF16 -template -static ET_INLINE void dot_with_fp32_arith_main_inner_loop( +ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void +dot_with_fp32_arith_main_inner_loop_bfdot( + const BFloat16* vec1, + const BFloat16* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast( + &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast( + &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + sum[registerPairIndex] = + f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2); +} + +static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot( + const BFloat16* vec1, + const BFloat16* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast( + &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast( + &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + + sum[2 * registerPairIndex] = f32_fma_bf16( + sum[2 * registerPairIndex], + vget_low_u16(temp_vec1), + vget_low_u16(temp_vec2)); + sum[2 * registerPairIndex + 1] = f32_fma_bf16( + sum[2 * registerPairIndex + 1], + vget_high_u16(temp_vec1), + vget_high_u16(temp_vec2)); +} + +template +ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void +dot_with_fp32_arith_main_inner_loop( const BFloat16* vec1, const BFloat16* vec2, float32x4_t sum[kF32RegistersPerIteration], int registerPairIndex) { -#ifdef __ARM_FEATURE_BF16 - if (useBfloat16Dot) { - const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast( - &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); - const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast( - &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); - sum[registerPairIndex] = - f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2); - } else -#endif // __ARM_FEATURE_BF16 - { - const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast( - &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); - const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast( - &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); - - sum[2 * registerPairIndex] = f32_fma_bf16( - sum[2 * registerPairIndex], - vget_low_u16(temp_vec1), - vget_low_u16(temp_vec2)); - sum[2 * registerPairIndex + 1] = f32_fma_bf16( - sum[2 * registerPairIndex + 1], - vget_high_u16(temp_vec1), - vget_high_u16(temp_vec2)); + if constexpr (useBfdot) { + dot_with_fp32_arith_main_inner_loop_bfdot( + vec1, vec2, sum, registerPairIndex); + } else { + dot_with_fp32_arith_main_inner_loop_no_bfdot( + vec1, vec2, sum, registerPairIndex); } } @@ -126,18 +143,40 @@ static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2); } -template -float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { +namespace { +template +struct ForcedUnrollTargetBFloat16 { + template + ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const { + ForcedUnrollTargetBFloat16{}(f); + f(n - 1); + } +}; + +template <> +struct ForcedUnrollTargetBFloat16<1> { + template + ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const { + f(0); + } +}; + +} // namespace + +template +ET_TARGET_ARM_BF16_ATTRIBUTE float +dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) { const auto* vec1_ = vec1 + j; const auto* vec2_ = vec2 + j; - utils::ForcedUnroll{}( - [vec1_, vec2_, &sum](auto k) ET_INLINE_ATTRIBUTE { - dot_with_fp32_arith_main_inner_loop( - vec1_, vec2_, sum, k); - }); + ForcedUnrollTargetBFloat16{}( + [vec1_, vec2_, &sum](auto k) + ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE { + dot_with_fp32_arith_main_inner_loop( + vec1_, vec2_, sum, k); + }); } auto reducedSum = reduce(sum); @@ -163,12 +202,9 @@ float bf16_dot_with_fp32_arith( const BFloat16* vec1, const BFloat16* vec2, int64_t len) { -#ifdef __ARM_FEATURE_BF16 if (cpuinfo_has_arm_bf16()) { return dot_with_fp32_arith(vec1, vec2, len); - } else -#endif // __ARM_FEATURE_BF16 - { + } else { return dot_with_fp32_arith(vec1, vec2, len); } } diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 11a5f3a966..367c23f081 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -132,14 +132,6 @@ def define_libs(): ] if not runtime.is_oss else [], "DEFAULT": [], }), - fbandroid_platform_compiler_flags = [ - ( - "^android-arm64.*$", - [ - "-march=armv8+bf16", - ], - ), - ], fbandroid_platform_preprocessor_flags = [ ( "^android-arm64.*$", @@ -156,12 +148,6 @@ def define_libs(): ], ), ], - fbobjc_platform_compiler_flags = [ - ( - ".*arm64.*", - ["-march=armv8+bf16"], - ), - ], fbobjc_exported_preprocessor_flags = [ "-DET_BUILD_WITH_BLAS", "-DET_BUILD_FOR_APPLE",