[libc][math][c23] Add f16sqrtf C23 math function #95251

Merged
merged 10 commits into llvm:main
Jun 13, 2024

Conversation

overmighty
Member

Part of #95250.

@llvmbot
Member

llvmbot commented Jun 12, 2024

@llvm/pr-subscribers-libc

Author: OverMighty (overmighty)

Changes

Part of #95250.


Patch is 31.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95251.diff

27 Files Affected:

  • (modified) libc/config/linux/aarch64/entrypoints.txt (+1)
  • (modified) libc/config/linux/x86_64/entrypoints.txt (+1)
  • (modified) libc/docs/math/index.rst (+2)
  • (modified) libc/spec/stdc.td (+2)
  • (modified) libc/src/__support/FPUtil/generic/sqrt.h (+112-29)
  • (modified) libc/src/__support/FPUtil/sqrt.h (+3-1)
  • (modified) libc/src/math/CMakeLists.txt (+2)
  • (added) libc/src/math/f16sqrtf.h (+20)
  • (modified) libc/src/math/generic/CMakeLists.txt (+12)
  • (added) libc/src/math/generic/f16sqrtf.cpp (+19)
  • (modified) libc/src/math/generic/sqrt.cpp (+1-1)
  • (modified) libc/src/math/generic/sqrtf.cpp (+1-1)
  • (modified) libc/src/math/generic/sqrtf128.cpp (+3-1)
  • (modified) libc/src/math/generic/sqrtl.cpp (+1-1)
  • (modified) libc/test/src/math/exhaustive/CMakeLists.txt (+15)
  • (modified) libc/test/src/math/exhaustive/exhaustive_test.h (+40)
  • (added) libc/test/src/math/exhaustive/f16sqrtf_test.cpp (+25)
  • (modified) libc/test/src/math/smoke/CMakeLists.txt (+11)
  • (modified) libc/test/src/math/smoke/SqrtTest.h (+11-17)
  • (added) libc/test/src/math/smoke/f16sqrtf_test.cpp (+13)
  • (modified) libc/test/src/math/smoke/sqrt_test.cpp (+1-1)
  • (modified) libc/test/src/math/smoke/sqrtf128_test.cpp (+1-1)
  • (modified) libc/test/src/math/smoke/sqrtf_test.cpp (+1-1)
  • (modified) libc/test/src/math/smoke/sqrtl_test.cpp (+1-1)
  • (modified) libc/utils/MPFRWrapper/CMakeLists.txt (+2)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+49)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.h (+21-1)
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index db96a80051a8d..2b2d0985a8992 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.canonicalizef16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
+    libc.src.math.f16sqrtf
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 355eaf33ace6d..2d36ca296c3a4 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.canonicalizef16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
+    libc.src.math.f16sqrtf
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst
index d556885eda622..8243b14ff4786 100644
--- a/libc/docs/math/index.rst
+++ b/libc/docs/math/index.rst
@@ -282,6 +282,8 @@ Higher Math Functions
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fsqrt     | N/A              |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
+| f16sqrt   | |check|          |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
++-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | hypot     | |check|          | |check|         |                        |                      |                        | 7.12.7.4               | F.10.4.4                   |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | lgamma    |                  |                 |                        |                      |                        | 7.12.8.3               | F.10.5.3                   |
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index b134ec00a7d7a..7c4135032a0b2 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -714,6 +714,8 @@ def StdC : StandardSpec<"stdc"> {
           GuardedFunctionSpec<"totalorderf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
 
           GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
+
+          GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
       ]
   >;
 
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index 7e7600ba6502a..4c95053217228 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -17,6 +17,7 @@
 #include "src/__support/FPUtil/rounding_mode.h"
 #include "src/__support/common.h"
 #include "src/__support/uint128.h"
+#include <fenv.h>
 
 namespace LIBC_NAMESPACE {
 namespace fputil {
@@ -64,40 +65,50 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
 
 // Correctly rounded IEEE 754 SQRT for all rounding modes.
 // Shift-and-add algorithm.
-template <typename T>
-LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
-
-  if constexpr (internal::SpecialLongDouble<T>::VALUE) {
+template <typename OutType, typename InType>
+LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
+                                 cpp::is_floating_point_v<InType> &&
+                                 sizeof(OutType) <= sizeof(InType),
+                             OutType>
+sqrt(InType x) {
+  if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
+                internal::SpecialLongDouble<InType>::VALUE) {
     // Special 80-bit long double.
     return x86::sqrt(x);
   } else {
     // IEEE floating points formats.
-    using FPBits_t = typename fputil::FPBits<T>;
-    using StorageType = typename FPBits_t::StorageType;
-    constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
-    constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();
-
-    FPBits_t bits(x);
-
-    if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
+    using OutFPBits = typename fputil::FPBits<OutType>;
+    using OutStorageType = typename OutFPBits::StorageType;
+    using InFPBits = typename fputil::FPBits<InType>;
+    using InStorageType = typename InFPBits::StorageType;
+    constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
+    constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
+    constexpr int EXTRA_FRACTION_LEN =
+        InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
+    constexpr InStorageType EXTRA_FRACTION_MASK =
+        (InStorageType(1) << EXTRA_FRACTION_LEN) - 1;
+
+    InFPBits bits(x);
+
+    if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
       // sqrt(+Inf) = +Inf
       // sqrt(+0) = +0
       // sqrt(-0) = -0
       // sqrt(NaN) = NaN
       // sqrt(-NaN) = -NaN
-      return x;
+      return static_cast<OutType>(x);
     } else if (bits.is_neg()) {
       // sqrt(-Inf) = NaN
       // sqrt(-x) = NaN
       return FLT_NAN;
     } else {
       int x_exp = bits.get_exponent();
-      StorageType x_mant = bits.get_mantissa();
+      InStorageType x_mant = bits.get_mantissa();
 
       // Step 1a: Normalize denormal input and append hidden bit to the mantissa
       if (bits.is_subnormal()) {
         ++x_exp; // let x_exp be the correct exponent of ONE bit.
-        internal::normalize<T>(x_exp, x_mant);
+        internal::normalize<InType>(x_exp, x_mant);
       } else {
         x_mant |= ONE;
       }
@@ -120,12 +131,13 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
       // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
       //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
       //         0 otherwise.
-      StorageType y = ONE;
-      StorageType r = x_mant - ONE;
+      InStorageType y = ONE;
+      InStorageType r = x_mant - ONE;
 
-      for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
+      for (InStorageType current_bit = ONE >> 1; current_bit;
+           current_bit >>= 1) {
         r <<= 1;
-        StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
+        InStorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
         if (r >= tmp) {
           r -= tmp;
           y += current_bit;
@@ -133,34 +145,105 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
       }
 
       // We compute one more iteration in order to round correctly.
-      bool lsb = static_cast<bool>(y & 1); // Least significant bit
-      bool rb = false;                     // Round bit
+      bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
+                 0;    // Least significant bit
+      bool rb = false; // Round bit
       r <<= 2;
-      StorageType tmp = (y << 2) + 1;
+      InStorageType tmp = (y << 2) + 1;
       if (r >= tmp) {
         r -= tmp;
         rb = true;
       }
 
+      bool sticky = false;
+
+      if constexpr (EXTRA_FRACTION_LEN > 0) {
+        sticky = rb || (y & EXTRA_FRACTION_MASK) != 0;
+        rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
+      }
+
       // Remove hidden bit and append the exponent field.
-      x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
+      x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);
+
+      OutStorageType y_out = static_cast<OutStorageType>(
+          ((y - ONE) >> EXTRA_FRACTION_LEN) |
+          (static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
+
+      if constexpr (EXTRA_FRACTION_LEN > 0) {
+        if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
+          switch (quick_get_round()) {
+          case FE_TONEAREST:
+          case FE_UPWARD:
+            return OutFPBits::inf().get_val();
+          default:
+            return OutFPBits::max_normal().get_val();
+          }
+        }
+
+        if (x_exp == OutFPBits::MAX_BIASED_EXPONENT - 1 &&
+            y == OutFPBits::max_normal().uintval() && (rb || sticky)) {
+          switch (quick_get_round()) {
+          case FE_TONEAREST:
+            if (rb)
+              return OutFPBits::inf().get_val();
+            return OutFPBits::max_normal().get_val();
+          case FE_UPWARD:
+            return OutFPBits::inf().get_val();
+          default:
+            return OutFPBits::max_normal().get_val();
+          }
+        }
 
-      y = (y - ONE) |
-          (static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
+        if (x_exp <
+            -OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
+          switch (quick_get_round()) {
+          case FE_UPWARD:
+            return OutFPBits::min_subnormal().get_val();
+          default:
+            return OutType(0.0);
+          }
+        }
+
+        if (x_exp <= 0) {
+          int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
+          InStorageType underflow_extra_fraction_mask =
+              (InStorageType(1) << underflow_extra_fraction_len) - 1;
+
+          rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
+               0;
+          OutStorageType subnormal_mant =
+              static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
+          lsb = (subnormal_mant & 1) != 0;
+          sticky = sticky || (y & underflow_extra_fraction_mask) != 0;
+
+          switch (quick_get_round()) {
+          case FE_TONEAREST:
+            if (rb && (lsb || sticky))
+              ++subnormal_mant;
+            break;
+          case FE_UPWARD:
+            if (rb || sticky)
+              ++subnormal_mant;
+            break;
+          }
+
+          return cpp::bit_cast<OutType>(subnormal_mant);
+        }
+      }
 
       switch (quick_get_round()) {
       case FE_TONEAREST:
         // Round to nearest, ties to even
         if (rb && (lsb || (r != 0)))
-          ++y;
+          ++y_out;
         break;
       case FE_UPWARD:
-        if (rb || (r != 0))
-          ++y;
+        if (rb || (r != 0) || sticky)
+          ++y_out;
         break;
       }
 
-      return cpp::bit_cast<T>(y);
+      return cpp::bit_cast<OutType>(y_out);
     }
   }
 }
diff --git a/libc/src/__support/FPUtil/sqrt.h b/libc/src/__support/FPUtil/sqrt.h
index eb86ddfa89d8e..d9c30c586bb0d 100644
--- a/libc/src/__support/FPUtil/sqrt.h
+++ b/libc/src/__support/FPUtil/sqrt.h
@@ -13,7 +13,9 @@
 #include "src/__support/macros/properties/cpu_features.h"
 
 #if defined(LIBC_TARGET_ARCH_IS_X86_64) && defined(LIBC_TARGET_CPU_HAS_SSE2)
-#include "x86_64/sqrt.h"
+// #include "x86_64/sqrt.h"
+// TODO
+#include "generic/sqrt.h"
 #elif defined(LIBC_TARGET_ARCH_IS_AARCH64)
 #include "aarch64/sqrt.h"
 #elif defined(LIBC_TARGET_ARCH_IS_ANY_RISCV)
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index 2446c293b8ef5..df8e6c0b253da 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f)
 add_math_entrypoint_object(expm1)
 add_math_entrypoint_object(expm1f)
 
+add_math_entrypoint_object(f16sqrtf)
+
 add_math_entrypoint_object(fabs)
 add_math_entrypoint_object(fabsf)
 add_math_entrypoint_object(fabsl)
diff --git a/libc/src/math/f16sqrtf.h b/libc/src/math/f16sqrtf.h
new file mode 100644
index 0000000000000..197ebe6db8016
--- /dev/null
+++ b/libc/src/math/f16sqrtf.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrtf ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF_H
+#define LLVM_LIBC_SRC_MATH_F16SQRTF_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrtf(float x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRTF_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index 673bef516b13d..45a28723ba6b0 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -3601,3 +3601,15 @@ add_entrypoint_object(
   COMPILE_OPTIONS
     -O3
 )
+
+add_entrypoint_object(
+  f16sqrtf
+  SRCS
+    f16sqrtf.cpp
+  HDRS
+    ../f16sqrtf.h
+  DEPENDS
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
diff --git a/libc/src/math/generic/f16sqrtf.cpp b/libc/src/math/generic/f16sqrtf.cpp
new file mode 100644
index 0000000000000..1f7ee2df29e86
--- /dev/null
+++ b/libc/src/math/generic/f16sqrtf.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrtf function -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrtf.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrtf, (float x)) {
+  return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrt.cpp b/libc/src/math/generic/sqrt.cpp
index b4d02785dcb43..f33b0a2cdcf74 100644
--- a/libc/src/math/generic/sqrt.cpp
+++ b/libc/src/math/generic/sqrt.cpp
@@ -12,6 +12,6 @@
 
 namespace LIBC_NAMESPACE {
 
-LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtf.cpp b/libc/src/math/generic/sqrtf.cpp
index bc74252295b3a..26a53e9077c1c 100644
--- a/libc/src/math/generic/sqrtf.cpp
+++ b/libc/src/math/generic/sqrtf.cpp
@@ -12,6 +12,6 @@
 
 namespace LIBC_NAMESPACE {
 
-LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt<float>(x); }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtf128.cpp b/libc/src/math/generic/sqrtf128.cpp
index 0196c3e0a96ae..70e28ddb692d4 100644
--- a/libc/src/math/generic/sqrtf128.cpp
+++ b/libc/src/math/generic/sqrtf128.cpp
@@ -12,6 +12,8 @@
 
 namespace LIBC_NAMESPACE {
 
-LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { return fputil::sqrt(x); }
+LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
+  return fputil::sqrt<float128>(x);
+}
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/sqrtl.cpp b/libc/src/math/generic/sqrtl.cpp
index b2aaa279f9c2a..9f0cc87853823 100644
--- a/libc/src/math/generic/sqrtl.cpp
+++ b/libc/src/math/generic/sqrtl.cpp
@@ -13,7 +13,7 @@
 namespace LIBC_NAMESPACE {
 
 LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
-  return fputil::sqrt(x);
+  return fputil::sqrt<long double>(x);
 }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/test/src/math/exhaustive/CMakeLists.txt b/libc/test/src/math/exhaustive/CMakeLists.txt
index 938e519aff084..34df8720ed4db 100644
--- a/libc/test/src/math/exhaustive/CMakeLists.txt
+++ b/libc/test/src/math/exhaustive/CMakeLists.txt
@@ -420,3 +420,18 @@ add_fp_unittest(
   LINK_LIBRARIES
     -lpthread
 )
+
+add_fp_unittest(
+  f16sqrtf_test
+  NO_RUN_POSTBUILD
+  NEED_MPFR
+  SUITE
+    libc_math_exhaustive_tests
+  SRCS
+    f16sqrtf_test.cpp
+  DEPENDS
+    .exhaustive_test
+    libc.src.math.f16sqrtf
+  LINK_LIBRARIES
+    -lpthread
+)
diff --git a/libc/test/src/math/exhaustive/exhaustive_test.h b/libc/test/src/math/exhaustive/exhaustive_test.h
index c4ae382688a03..1f8daf497ab2f 100644
--- a/libc/test/src/math/exhaustive/exhaustive_test.h
+++ b/libc/test/src/math/exhaustive/exhaustive_test.h
@@ -68,6 +68,41 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
   }
 };
 
+template <typename OutType, typename InType>
+using UnaryNarrowerOp = OutType(InType);
+
+template <typename OutType, typename InType, mpfr::Operation Op,
+          UnaryNarrowerOp<OutType, InType> Func>
+struct UnaryNarrowerOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
+  using FloatType = InType;
+  using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
+  using StorageType = typename FPBits::StorageType;
+
+  static constexpr UnaryNarrowerOp<OutType, FloatType> *FUNC = Func;
+
+  // Check in a range, return the number of failures.
+  uint64_t check(StorageType start, StorageType stop,
+                 mpfr::RoundingMode rounding) {
+    mpfr::ForceRoundingMode r(rounding);
+    if (!r.success)
+      return (stop > start);
+    StorageType bits = start;
+    uint64_t failed = 0;
+    do {
+      FPBits xbits(bits);
+      FloatType x = xbits.get_val();
+      bool correct =
+          TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, FUNC(x), 0.5, rounding);
+      failed += (!correct);
+      // Print out failed values.
+      if (!correct) {
+        EXPECT_MPFR_MATCH_ROUNDING(Op, x, FUNC(x), 0.5, rounding);
+      }
+    } while (bits++ < stop);
+    return failed;
+  }
+};
+
 // Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
 //   StorageType and check method.
 template <typename Checker>
@@ -170,3 +205,8 @@ struct LlvmLibcExhaustiveMathTest
 template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
 using LlvmLibcUnaryOpExhaustiveMathTest =
     LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, Op, Func>>;
+
+template <typename OutType, typename InType, mpfr::Operation Op,
+          UnaryNarrowerOp<OutType, InType> Func>
+using LlvmLibcUnaryNarrowerOpExhaustiveMathTest = LlvmLibcExhaustiveMathTest<
+    UnaryNarrowerOpChecker<OutType, InType, Op, Func>>;
diff --git a/libc/test/src/math/exhaustive/f16sqrtf_test.cpp b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
new file mode 100644
index 0000000000000..5bc04f6bdc7cf
--- /dev/null
+++ b/libc/test/src/math/exhaustive/f16sqrtf_test.cpp
@@ -0,0 +1,25 @@
+//===-- Exhaustive test for f16sqrtf --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "exhaustive_test.h"
+#include "src/math/f16sqrtf.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+using LlvmLibcF16sqrtfExhaustiveTest =
+    LlvmLibcUnaryNarrowerOpExhaustiveMathTest<
+        float16, float, mpfr::Operation::Sqrt, LIBC_NAMESPACE::f16sqrtf>;
+
+// Range: [0, Inf];
+static constexpr uint32_t POS_START = 0x0000'0000U;
+static constexpr uint32_t POS_STOP = 0x7f80'0000U;
+
+TEST_F(LlvmLibcF16sqrtfExhaustiveTest, PositiveRange) {
+  test_full_range_all_roundings(POS_START, POS_STOP);
+}
diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt
index 68cd412b14e9d..d67f5abd2ab1c 100644
--- a/libc/test/src/math/smoke/CMakeLists.txt
+++ b/libc/test/src/math/smoke/CMakeLists.txt
@@ -3543,3 +3543,14 @@ add_fp_unittest(
   DEPENDS
     libc.src.math.totalordermagf16
 )
+
+add_fp_unittest(
+  f16sqrtf_test
+  SUITE
+    libc-math-smoke-tests
+  SRCS
+    f16sqrtf_test.cpp
+  DEPENDS
+    libc.src.math.f16sqrtf
+    libc.src.__support.FPUtil.fp_bits
+)
diff --git a/libc/test/src/math/smoke/SqrtTest...
[truncated]
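
The `generic/sqrt.h` hunk above narrows a wider significand to the output format by tracking a least significant bit, a round bit, and a sticky bit. The following is a minimal sketch of that general scheme only, not the patch's code: `round_narrow`, `Round`, and `tail_nonzero` are hypothetical names introduced for illustration, and the sketch ignores the exponent, overflow, and subnormal handling that the patch performs.

```cpp
#include <cstdint>

// Illustrative rounding modes (assumption). The patch itself switches on the
// <fenv.h> FE_* values returned by quick_get_round().
enum class Round { Nearest, Upward, Truncate };

// Hypothetical helper, not part of LLVM libc: drop the low `extra` bits of a
// non-negative significand `sig` and round the part that is kept.
//   lsb    - lowest bit that is kept
//   rb     - round bit, the highest bit that is dropped
//   sticky - true if anything below the round bit is nonzero
// `tail_nonzero` reports nonzero bits beyond `sig` itself (for a square root,
// a nonzero remainder). Assumes 0 < extra < 32.
constexpr std::uint32_t round_narrow(std::uint32_t sig, int extra,
                                     bool tail_nonzero, Round mode) {
  std::uint32_t kept = sig >> extra;
  const bool lsb = (kept & 1u) != 0;
  const bool rb = ((sig >> (extra - 1)) & 1u) != 0;
  const std::uint32_t below_rb = (std::uint32_t(1) << (extra - 1)) - 1;
  const bool sticky = tail_nonzero || (sig & below_rb) != 0;

  switch (mode) {
  case Round::Nearest: // round to nearest, ties to even
    if (rb && (lsb || sticky))
      ++kept;
    break;
  case Round::Upward:
    if (rb || sticky)
      ++kept;
    break;
  case Round::Truncate: // downward or toward zero for non-negative values
    break;
  }
  return kept;
}
```

With two extra bits, `round_narrow(0b1010, 2, false, Round::Nearest)` is an exact tie and stays at the even value 2, while `round_narrow(0b1110, 2, false, Round::Nearest)` rounds up to 4. For the `float` to `float16` case in this PR the number of extra fraction bits would be 23 - 10 = 13.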

Comment on lines +783 to +784
if (msg.overflow())
__builtin_unreachable();
Member Author

The buffer size of 1024 should always be much more than enough.
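
For readers unfamiliar with the pattern, here is a rough sketch of a fixed-capacity message buffer that records overflow instead of growing. `FixedMessage` is a hypothetical stand-in written for this note; it is not the stream type used in `MPFRUtils.cpp`, whose definition is not shown in the truncated diff.

```cpp
#include <cstddef>
#include <cstring>

// Hypothetical fixed-capacity message buffer (assumption, not LLVM libc's
// API). Appends are truncated when they do not fit, and the overflow is
// remembered so that callers can check it afterwards. Requires N >= 1.
template <std::size_t N> class FixedMessage {
  char buf_[N] = {};
  std::size_t len_ = 0;
  bool overflow_ = false;

public:
  // Append a C string, keeping one byte free for the terminator.
  FixedMessage &operator<<(const char *s) {
    std::size_t n = std::strlen(s);
    const std::size_t room = N - 1 - len_;
    if (n > room) {
      n = room;
      overflow_ = true;
    }
    std::memcpy(buf_ + len_, s, n);
    len_ += n;
    buf_[len_] = '\0';
    return *this;
  }

  bool overflow() const { return overflow_; }
  std::size_t size() const { return len_; }
  const char *c_str() const { return buf_; }
};
```

Marking the overflow branch with `__builtin_unreachable()`, as the reviewed lines do, tells the compiler the branch can never be taken, so it is only safe if the capacity really is large enough for every message.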

Comment on lines 766 to 768
MPFRNumber mpfrInput(input, precision);
MPFRNumber mpfr_result;
mpfr_result = unary_operation(op, input, precision, rounding);
Member Author

Not written by me, but the naming/case isn't consistent, and the assignment could be merged with the declaration.
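
A small sketch of what that cleanup might look like, using stand-in types so that it compiles on its own. `MpfrLikeNumber`, `Operation`, this `unary_operation`, and `evaluate` are all hypothetical here; the real `MPFRNumber` and `unary_operation` in `MPFRUtils.cpp` have different definitions and take a rounding mode as well.

```cpp
#include <cmath>

// Hypothetical stand-in for MPFRNumber (assumption, for illustration only).
struct MpfrLikeNumber {
  double value = 0.0;
  unsigned precision = 0;
  MpfrLikeNumber() = default;
  MpfrLikeNumber(double v, unsigned p) : value(v), precision(p) {}
};

enum class Operation { Sqrt };

// Hypothetical stand-in for the real unary_operation helper.
inline MpfrLikeNumber unary_operation(Operation op, double input,
                                      unsigned precision) {
  switch (op) {
  case Operation::Sqrt:
    return MpfrLikeNumber(std::sqrt(input), precision);
  }
  return MpfrLikeNumber();
}

inline double evaluate(Operation op, double input, unsigned precision) {
  // Both locals use snake_case, and the result is initialized where it is
  // declared instead of being default-constructed and then assigned.
  MpfrLikeNumber mpfr_input(input, precision);
  MpfrLikeNumber mpfr_result = unary_operation(op, input, precision);
  (void)mpfr_input; // kept only to mirror the reviewed lines
  return mpfr_result.value;
}
```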

@lntue lntue merged commit a239343 into llvm:main Jun 13, 2024
5 of 6 checks passed
EthanLuisMcDonough pushed a commit to EthanLuisMcDonough/llvm-project that referenced this pull request Aug 13, 2024