Skip to content

[libc++][math] Fix undue overflowing of std::hypot(x,y,z) #100820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions libcxx/include/__math/hypot.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,25 @@
#ifndef _LIBCPP___MATH_HYPOT_H
#define _LIBCPP___MATH_HYPOT_H

#include <__algorithm/max.h>
#include <__config>
#include <__math/abs.h>
#include <__math/exponential_functions.h>
#include <__math/roots.h>
#include <__type_traits/enable_if.h>
#include <__type_traits/is_arithmetic.h>
#include <__type_traits/is_same.h>
#include <__type_traits/promote.h>
#include <__utility/pair.h>
#include <limits>

#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
#endif

_LIBCPP_PUSH_MACROS
#include <__undef_macros>

_LIBCPP_BEGIN_NAMESPACE_STD

namespace __math {
Expand All @@ -41,8 +50,60 @@ inline _LIBCPP_HIDE_FROM_ABI typename __promote<_A1, _A2>::type hypot(_A1 __x, _
return __math::hypot((__result_type)__x, (__result_type)__y);
}

#if _LIBCPP_STD_VER >= 17
// Computes the three-dimensional hypotenuse: `std::hypot(x,y,z)`.
// The naive implementation might over-/underflow which is why this implementation is more involved:
// If the square of an argument might run into issues, we scale the arguments appropriately.
// See https://github.com/llvm/llvm-project/issues/92782 for a detailed discussion and summary.
template <class _Real>
_LIBCPP_HIDE_FROM_ABI _Real __hypot(_Real __x, _Real __y, _Real __z) {
// Factors needed to determine if over-/underflow might happen
constexpr int __exp = std::numeric_limits<_Real>::max_exponent / 2;
const _Real __overflow_threshold = __math::ldexp(_Real(1), __exp);
const _Real __overflow_scale = __math::ldexp(_Real(1), -(__exp + 20));

// Scale arguments depending on their size
const _Real __max_abs = std::max(__math::fabs(__x), std::max(__math::fabs(__y), __math::fabs(__z)));
_Real __scale;
if (__max_abs > __overflow_threshold) { // x*x + y*y + z*z might overflow
__scale = __overflow_scale;
} else if (__max_abs < 1 / __overflow_threshold) { // x*x + y*y + z*z might underflow
__scale = 1 / __overflow_scale;
} else {
__scale = 1;
}
__x *= __scale;
__y *= __scale;
__z *= __scale;

// Compute hypot of scaled arguments and undo scaling
return __math::sqrt(__x * __x + __y * __y + __z * __z) / __scale;
}

inline _LIBCPP_HIDE_FROM_ABI float hypot(float __x, float __y, float __z) { return __math::__hypot(__x, __y, __z); }

inline _LIBCPP_HIDE_FROM_ABI double hypot(double __x, double __y, double __z) { return __math::__hypot(__x, __y, __z); }

inline _LIBCPP_HIDE_FROM_ABI long double hypot(long double __x, long double __y, long double __z) {
return __math::__hypot(__x, __y, __z);
}

template <class _A1,
class _A2,
class _A3,
std::enable_if_t< is_arithmetic_v<_A1> && is_arithmetic_v<_A2> && is_arithmetic_v<_A3>, int> = 0 >
_LIBCPP_HIDE_FROM_ABI typename __promote<_A1, _A2, _A3>::type hypot(_A1 __x, _A2 __y, _A3 __z) _NOEXCEPT {
using __result_type = typename __promote<_A1, _A2, _A3>::type;
static_assert(!(
std::is_same_v<_A1, __result_type> && std::is_same_v<_A2, __result_type> && std::is_same_v<_A3, __result_type>));
return __math::__hypot(
static_cast<__result_type>(__x), static_cast<__result_type>(__y), static_cast<__result_type>(__z));
}
#endif

} // namespace __math

_LIBCPP_END_NAMESPACE_STD
_LIBCPP_POP_MACROS

#endif // _LIBCPP___MATH_HYPOT_H
25 changes: 1 addition & 24 deletions libcxx/include/cmath
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ constexpr long double lerp(long double a, long double b, long double t) noexcept
*/

#include <__config>
#include <__math/hypot.h>
#include <__type_traits/enable_if.h>
#include <__type_traits/is_arithmetic.h>
#include <__type_traits/is_constant_evaluated.h>
Expand Down Expand Up @@ -553,30 +554,6 @@ using ::scalbnl _LIBCPP_USING_IF_EXISTS;
using ::tgammal _LIBCPP_USING_IF_EXISTS;
using ::truncl _LIBCPP_USING_IF_EXISTS;

#if _LIBCPP_STD_VER >= 17
inline _LIBCPP_HIDE_FROM_ABI float hypot(float __x, float __y, float __z) {
return sqrt(__x * __x + __y * __y + __z * __z);
}
inline _LIBCPP_HIDE_FROM_ABI double hypot(double __x, double __y, double __z) {
return sqrt(__x * __x + __y * __y + __z * __z);
}
inline _LIBCPP_HIDE_FROM_ABI long double hypot(long double __x, long double __y, long double __z) {
return sqrt(__x * __x + __y * __y + __z * __z);
}

template <class _A1, class _A2, class _A3>
inline _LIBCPP_HIDE_FROM_ABI
typename enable_if_t< is_arithmetic<_A1>::value && is_arithmetic<_A2>::value && is_arithmetic<_A3>::value,
__promote<_A1, _A2, _A3> >::type
hypot(_A1 __lcpp_x, _A2 __lcpp_y, _A3 __lcpp_z) _NOEXCEPT {
typedef typename __promote<_A1, _A2, _A3>::type __result_type;
static_assert(
!(is_same<_A1, __result_type>::value && is_same<_A2, __result_type>::value && is_same<_A3, __result_type>::value),
"");
return std::hypot((__result_type)__lcpp_x, (__result_type)__lcpp_y, (__result_type)__lcpp_z);
}
#endif

template <class _A1, __enable_if_t<is_floating_point<_A1>::value, int> = 0>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR bool __constexpr_isnan(_A1 __lcpp_x) _NOEXCEPT {
#if __has_builtin(__builtin_isnan)
Expand Down
3 changes: 3 additions & 0 deletions libcxx/test/libcxx/transitive_includes/cxx03.csv
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ chrono type_traits
chrono vector
chrono version
cinttypes cstdint
cmath cstddef
cmath cstdint
cmath initializer_list
cmath limits
cmath type_traits
cmath version
Expand Down
3 changes: 3 additions & 0 deletions libcxx/test/libcxx/transitive_includes/cxx11.csv
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ chrono type_traits
chrono vector
chrono version
cinttypes cstdint
cmath cstddef
cmath cstdint
cmath initializer_list
cmath limits
cmath type_traits
cmath version
Expand Down
3 changes: 3 additions & 0 deletions libcxx/test/libcxx/transitive_includes/cxx14.csv
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ chrono type_traits
chrono vector
chrono version
cinttypes cstdint
cmath cstddef
cmath cstdint
cmath initializer_list
cmath limits
cmath type_traits
cmath version
Expand Down
3 changes: 3 additions & 0 deletions libcxx/test/libcxx/transitive_includes/cxx17.csv
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ chrono type_traits
chrono vector
chrono version
cinttypes cstdint
cmath cstddef
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You seem to be missing some transitive includes additions in c++03, that's what is breaking the CI.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know how I can run specifically that failing test locally?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

libcxx/utils/libcxx-lit <BUILD DIR> -sv libcxx/test/libcxx/transitive_includes.gen.py --param std=c++03

cmath cstdint
cmath initializer_list
cmath limits
cmath type_traits
cmath version
Expand Down
3 changes: 3 additions & 0 deletions libcxx/test/libcxx/transitive_includes/cxx20.csv
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ chrono type_traits
chrono vector
chrono version
cinttypes cstdint
cmath cstddef
cmath cstdint
cmath initializer_list
cmath limits
cmath type_traits
cmath version
Expand Down
3 changes: 3 additions & 0 deletions libcxx/test/libcxx/transitive_includes/cxx23.csv
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ chrono string_view
chrono vector
chrono version
cinttypes cstdint
cmath cstddef
cmath cstdint
cmath initializer_list
cmath limits
cmath version
codecvt cctype
Expand Down
3 changes: 3 additions & 0 deletions libcxx/test/libcxx/transitive_includes/cxx26.csv
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ chrono string_view
chrono vector
chrono version
cinttypes cstdint
cmath cstddef
cmath cstdint
cmath initializer_list
cmath limits
cmath version
codecvt cctype
Expand Down
91 changes: 75 additions & 16 deletions libcxx/test/std/numerics/c.math/cmath.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@

// <cmath>

#include <array>
#include <cmath>
#include <limits>
#include <type_traits>
#include <cassert>

#include "fp_compare.h"
#include "test_macros.h"
#include "hexfloat.h"
#include "truncate_fp.h"
#include "type_algorithms.h"

// convertible to int/float/double/etc
template <class T, int N=0>
Expand Down Expand Up @@ -1113,6 +1116,56 @@ void test_fmin()
assert(std::fmin(1,0) == 0);
}

#if TEST_STD_VER >= 17
struct TestHypot3 {
template <class Real>
void operator()() const {
const auto check = [](Real elem, Real abs_tol) {
assert(std::isfinite(std::hypot(elem, Real(0), Real(0))));
assert(fptest_close(std::hypot(elem, Real(0), Real(0)), elem, abs_tol));
assert(std::isfinite(std::hypot(elem, elem, Real(0))));
assert(fptest_close(std::hypot(elem, elem, Real(0)), std::sqrt(Real(2)) * elem, abs_tol));
assert(std::isfinite(std::hypot(elem, elem, elem)));
assert(fptest_close(std::hypot(elem, elem, elem), std::sqrt(Real(3)) * elem, abs_tol));
};

{ // check for overflow
const auto [elem, abs_tol] = []() -> std::array<Real, 2> {
if constexpr (std::is_same_v<Real, float>)
return {1e20f, 1e16f};
else if constexpr (std::is_same_v<Real, double>)
return {1e300, 1e287};
else { // long double
# if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__
return {1e300l, 1e287l}; // 64-bit
# else
return {1e4000l, 1e3985l}; // 80- or 128-bit
# endif
}
}();
check(elem, abs_tol);
}

{ // check for underflow
const auto [elem, abs_tol] = []() -> std::array<Real, 2> {
if constexpr (std::is_same_v<Real, float>)
return {1e-20f, 1e-24f};
else if constexpr (std::is_same_v<Real, double>)
return {1e-287, 1e-300};
else { // long double
# if __DBL_MAX_EXP__ == __LDBL_MAX_EXP__
return {1e-287l, 1e-300l}; // 64-bit
# else
return {1e-3985l, 1e-4000l}; // 80- or 128-bit
# endif
}
}();
check(elem, abs_tol);
}
}
};
#endif

void test_hypot()
{
static_assert((std::is_same<decltype(std::hypot((float)0, (float)0)), float>::value), "");
Expand All @@ -1135,25 +1188,31 @@ void test_hypot()
static_assert((std::is_same<decltype(hypot(Ambiguous(), Ambiguous())), Ambiguous>::value), "");
assert(std::hypot(3,4) == 5);

#if TEST_STD_VER > 14
static_assert((std::is_same<decltype(std::hypot((float)0, (float)0, (float)0)), float>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (bool)0, (float)0)), double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (unsigned short)0, (double)0)), double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (int)0, (long double)0)), long double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (double)0, (long)0)), double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (long double)0, (unsigned long)0)), long double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (int)0, (long long)0)), double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (int)0, (unsigned long long)0)), double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (double)0, (double)0)), double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (long double)0, (long double)0)), long double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (float)0, (double)0)), double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (float)0, (long double)0)), long double>::value), "");
static_assert((std::is_same<decltype(std::hypot((float)0, (double)0, (long double)0)), long double>::value), "");
static_assert((std::is_same<decltype(std::hypot((int)0, (int)0, (int)0)), double>::value), "");
static_assert((std::is_same<decltype(hypot(Ambiguous(), Ambiguous(), Ambiguous())), Ambiguous>::value), "");
#if TEST_STD_VER >= 17
// clang-format off
static_assert((std::is_same_v<decltype(std::hypot((float)0, (float)0, (float)0)), float>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (bool)0, (float)0)), double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (unsigned short)0, (double)0)), double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (int)0, (long double)0)), long double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (double)0, (long)0)), double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (long double)0, (unsigned long)0)), long double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (int)0, (long long)0)), double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (int)0, (unsigned long long)0)), double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (double)0, (double)0)), double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (long double)0, (long double)0)), long double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (float)0, (double)0)), double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (float)0, (long double)0)), long double>));
static_assert((std::is_same_v<decltype(std::hypot((float)0, (double)0, (long double)0)), long double>));
static_assert((std::is_same_v<decltype(std::hypot((int)0, (int)0, (int)0)), double>));
static_assert((std::is_same_v<decltype(hypot(Ambiguous(), Ambiguous(), Ambiguous())), Ambiguous>));
// clang-format on

assert(std::hypot(2,3,6) == 7);
assert(std::hypot(1,4,8) == 9);

// Check for undue over-/underflows of intermediate results.
// See discussion at https://github.com/llvm/llvm-project/issues/92782.
types::for_each(types::floating_point_types(), TestHypot3());
#endif
}

Expand Down
45 changes: 20 additions & 25 deletions libcxx/test/support/fp_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,34 @@
#ifndef SUPPORT_FP_COMPARE_H
#define SUPPORT_FP_COMPARE_H

#include <cmath> // for std::abs
#include <algorithm> // for std::max
#include <cmath> // for std::abs
#include <algorithm> // for std::max
#include <cassert>
#include <__config>

// See https://www.boost.org/doc/libs/1_70_0/libs/test/doc/html/boost_test/testing_tools/extended_comparison/floating_point/floating_points_comparison_theory.html

template<typename T>
bool fptest_close(T val, T expected, T eps)
{
constexpr T zero = T(0);
assert(eps >= zero);
template <typename T>
bool fptest_close(T val, T expected, T eps) {
_LIBCPP_CONSTEXPR T zero = T(0);
assert(eps >= zero);

// Handle the zero cases
if (eps == zero) return val == expected;
if (val == zero) return std::abs(expected) <= eps;
if (expected == zero) return std::abs(val) <= eps;
// Handle the zero cases
if (eps == zero)
return val == expected;
if (val == zero)
return std::abs(expected) <= eps;
if (expected == zero)
return std::abs(val) <= eps;

return std::abs(val - expected) < eps
&& std::abs(val - expected)/std::abs(val) < eps;
return std::abs(val - expected) < eps && std::abs(val - expected) / std::abs(val) < eps;
}

template<typename T>
bool fptest_close_pct(T val, T expected, T percent)
{
constexpr T zero = T(0);
assert(percent >= zero);

// Handle the zero cases
if (percent == zero) return val == expected;
T eps = (percent / T(100)) * std::max(std::abs(val), std::abs(expected));

return fptest_close(val, expected, eps);
template <typename T>
bool fptest_close_pct(T val, T expected, T percent) {
assert(percent >= T(0));
T eps = (percent / T(100)) * std::max(std::abs(val), std::abs(expected));
return fptest_close(val, expected, eps);
}


#endif // SUPPORT_FP_COMPARE_H
Loading