Skip to content

Commit 2ea8653

Browse files
malfet authored and pytorchmergebot committed
[vec128] Fix fmsub NEON definition (pytorch#152075)
As reported in pytorch#149292, according to the manual, `vfmsq_f32(c, a, b)` implements `c - a * b` rather than `a * b - c`, so the call must be wrapped in `vnegq_f32` to negate its result. Also, adjust the tests to use OpMath for the FMA computation, to avoid accuracy error accumulation due to non-fused multiply-and-add over lower-precision dtypes. Note that `Vectorized::fmsub` is not currently instantiated anywhere, which is why it could remain broken without being noticed. TODO: - Enable C++ testing on MacOS and/or aarch64 platforms (right now Mac tests are built without C++ tests). Fixes pytorch#149292. Pull Request resolved: pytorch#152075. Approved by: https://github.com/swolchok. ghstack dependencies: pytorch#151955
1 parent 5e9bdc9 commit 2ea8653

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<floa
540540

541541
template <>
542542
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
543-
return Vectorized<float>(vfmsq_f32(c, a, b));
543+
return Vectorized<float>(vnegq_f32(vfmsq_f32(c, a, b)));
544544
}
545545

546546
inline Vectorized<float> Vectorized<float>::erf() const{

aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ Vectorized<c10::Half> inline fmsub(
582582
const Vectorized<c10::Half>& b,
583583
const Vectorized<c10::Half>& c) {
584584
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
585-
return Vectorized<c10::Half>(vfmsq_f16(c, a, b));
585+
return Vectorized<c10::Half>(vnegq_f16(vfmsq_f16(c, a, b)));
586586
#else
587587
return a * b - c;
588588
#endif

aten/src/ATen/test/vec_test_all_types.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ CACHE_ALIGN #define
6464
#undef CHECK_WITH_FMA
6565
#endif
6666

67+
template <typename scalar_t>
68+
struct OpMathType {
69+
using type = scalar_t;
70+
};
71+
template <>
72+
struct OpMathType<c10::Half> {
73+
using type = float;
74+
};
75+
76+
6777
template<typename T>
6878
using Complex = typename c10::complex<T>;
6979

@@ -1317,15 +1327,17 @@ std::enable_if_t<is_complex<Complex<T>>::value, Complex<T>> local_division(Compl
13171327
template <typename T>
13181328
std::enable_if_t<!is_complex<T>::value, T> local_fmadd(T a, T b, T c) {
13191329
PreventFma noFma;
1320-
T ab = a * b;
1321-
return noFma.add(ab, c);
1330+
using op_math_t = typename OpMathType<T>::type;
1331+
auto ab = static_cast<op_math_t>(a) * static_cast<op_math_t>(b);
1332+
return static_cast<T>(noFma.add(ab, op_math_t(c)));
13221333
}
13231334

13241335
template <typename T>
13251336
std::enable_if_t<!is_complex<T>::value, T> local_fmsub(T a, T b, T c) {
13261337
PreventFma noFma;
1327-
T ab = a * b;
1328-
return noFma.sub(ab, c);
1338+
using op_math_t = typename OpMathType<T>::type;
1339+
auto ab = static_cast<op_math_t>(a) * static_cast<op_math_t>(b);
1340+
return static_cast<T>(noFma.sub(ab, op_math_t(c)));
13291341
}
13301342

13311343
template <typename T>

0 commit comments

Comments
 (0)