Skip to content

[libc][math][c23] Add MPFR unit test for f16sqrtf #97062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2024

Conversation

overmighty
Copy link
Member

No description provided.

@overmighty
Copy link
Member Author

Requires #96642.

@overmighty overmighty marked this pull request as ready for review July 1, 2024 10:24
@llvmbot llvmbot added the libc label Jul 1, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 1, 2024

@llvm/pr-subscribers-libc

Author: OverMighty (overmighty)

Changes

Patch is 28.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97062.diff

24 Files Affected:

  • (modified) libc/config/linux/aarch64/entrypoints.txt (+9)
  • (modified) libc/config/linux/x86_64/entrypoints.txt (+3)
  • (modified) libc/docs/math/index.rst (+1-1)
  • (modified) libc/spec/stdc.td (+3)
  • (modified) libc/src/__support/FPUtil/generic/CMakeLists.txt (+1)
  • (modified) libc/src/__support/FPUtil/generic/sqrt.h (+14-88)
  • (modified) libc/src/math/CMakeLists.txt (+3)
  • (added) libc/src/math/f16sqrt.h (+20)
  • (added) libc/src/math/f16sqrtf128.h (+20)
  • (added) libc/src/math/f16sqrtl.h (+20)
  • (modified) libc/src/math/generic/CMakeLists.txt (+39)
  • (added) libc/src/math/generic/f16sqrt.cpp (+19)
  • (added) libc/src/math/generic/f16sqrtf128.cpp (+19)
  • (added) libc/src/math/generic/f16sqrtl.cpp (+19)
  • (modified) libc/test/src/math/CMakeLists.txt (+51-6)
  • (modified) libc/test/src/math/SqrtTest.h (+16-27)
  • (added) libc/test/src/math/f16sqrt_test.cpp (+13)
  • (added) libc/test/src/math/f16sqrtf_test.cpp (+13)
  • (added) libc/test/src/math/f16sqrtl_test.cpp (+13)
  • (modified) libc/test/src/math/smoke/CMakeLists.txt (+36)
  • (added) libc/test/src/math/smoke/f16sqrt_test.cpp (+13)
  • (added) libc/test/src/math/smoke/f16sqrtf128_test.cpp (+13)
  • (added) libc/test/src/math/smoke/f16sqrtl_test.cpp (+13)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+11)
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index fbf8c4b5a7581..7774c956cbc75 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -509,7 +509,9 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.copysignf16
     libc.src.math.f16divf
     libc.src.math.f16fmaf
+    libc.src.math.f16sqrt
     libc.src.math.f16sqrtf
+    libc.src.math.f16sqrtl
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
@@ -560,6 +562,13 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.ufromfpf16
     libc.src.math.ufromfpxf16
   )
+
+  if(LIBC_TYPES_HAS_FLOAT128)
+    list(APPEND TARGET_LIBM_ENTRYPOINTS
+      # math.h C23 mixed _Float16 and _Float128 entrypoints
+      libc.src.math.f16sqrtf128
+    )
+  endif()
 endif()
 
 if(LIBC_TYPES_HAS_FLOAT128)
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 9581f7f2604c4..cedcb423388b5 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -541,7 +541,9 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.f16fma
     libc.src.math.f16fmaf
     libc.src.math.f16fmal
+    libc.src.math.f16sqrt
     libc.src.math.f16sqrtf
+    libc.src.math.f16sqrtl
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
@@ -595,6 +597,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     list(APPEND TARGET_LIBM_ENTRYPOINTS
       # math.h C23 mixed _Float16 and _Float128 entrypoints
       libc.src.math.f16fmaf128
+      libc.src.math.f16sqrtf128
     )
   endif()
 endif()
diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst
index 56cc8d658257d..816eb54699f9d 100644
--- a/libc/docs/math/index.rst
+++ b/libc/docs/math/index.rst
@@ -292,7 +292,7 @@ Higher Math Functions
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fma       | |check|          | |check|         |                        |                      |                        | 7.12.13.1              | F.10.10.1                  |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
-| f16sqrt   | |check|          |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
+| f16sqrt   | |check|          | |check|         | |check|                | N/A                  | |check|                | 7.12.14.6              | F.10.11                    |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fsqrt     | N/A              |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index adac7d5932428..62faef450962c 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -731,7 +731,10 @@ def StdC : StandardSpec<"stdc"> {
 
           GuardedFunctionSpec<"f16divf", RetValSpec<Float16Type>, [ArgSpec<FloatType>, ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
 
+          GuardedFunctionSpec<"f16sqrt", RetValSpec<Float16Type>, [ArgSpec<DoubleType>], "LIBC_TYPES_HAS_FLOAT16">,
           GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
+          GuardedFunctionSpec<"f16sqrtl", RetValSpec<Float16Type>, [ArgSpec<LongDoubleType>], "LIBC_TYPES_HAS_FLOAT16">,
+          GuardedFunctionSpec<"f16sqrtf128", RetValSpec<Float16Type>, [ArgSpec<Float128Type>], "LIBC_TYPES_HAS_FLOAT16_AND_FLOAT128">,
       ]
   >;
 
diff --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt
index bd8af98473edf..fb49fd0039537 100644
--- a/libc/src/__support/FPUtil/generic/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/generic/CMakeLists.txt
@@ -8,6 +8,7 @@ add_header_library(
     libc.src.__support.common
     libc.src.__support.CPP.bit
     libc.src.__support.CPP.type_traits
+    libc.src.__support.FPUtil.dyadic_float
     libc.src.__support.FPUtil.fenv_impl
     libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.rounding_mode
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index d6e894fdfe021..a1f81b0d25da6 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -14,7 +14,7 @@
 #include "src/__support/CPP/type_traits.h"
 #include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
-#include "src/__support/FPUtil/rounding_mode.h"
+#include "src/__support/FPUtil/dyadic_float.h"
 #include "src/__support/common.h"
 #include "src/__support/uint128.h"
 
@@ -78,16 +78,14 @@ sqrt(InType x) {
     return x86::sqrt(x);
   } else {
     // IEEE floating points formats.
-    using OutFPBits = typename fputil::FPBits<OutType>;
-    using OutStorageType = typename OutFPBits::StorageType;
-    using InFPBits = typename fputil::FPBits<InType>;
+    using OutFPBits = FPBits<OutType>;
+    using InFPBits = FPBits<InType>;
     using InStorageType = typename InFPBits::StorageType;
+    using DyadicFloat =
+        DyadicFloat<cpp::bit_ceil(static_cast<size_t>(InFPBits::STORAGE_LEN))>;
+
     constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
     constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
-    constexpr int EXTRA_FRACTION_LEN =
-        InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
-    constexpr InStorageType EXTRA_FRACTION_MASK =
-        (InStorageType(1) << EXTRA_FRACTION_LEN) - 1;
 
     InFPBits bits(x);
 
@@ -146,91 +144,19 @@ sqrt(InType x) {
       }
 
       // We compute one more iteration in order to round correctly.
-      bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
-                 0;    // Least significant bit
-      bool rb = false; // Round bit
       r <<= 2;
-      InStorageType tmp = (y << 2) + 1;
+      y <<= 2;
+      InStorageType tmp = y + 1;
       if (r >= tmp) {
         r -= tmp;
-        rb = true;
-      }
-
-      bool sticky = false;
-
-      if constexpr (EXTRA_FRACTION_LEN > 0) {
-        sticky = rb || (y & EXTRA_FRACTION_MASK) != 0;
-        rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
-      }
-
-      // Remove hidden bit and append the exponent field.
-      x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);
-
-      OutStorageType y_out = static_cast<OutStorageType>(
-          ((y - ONE) >> EXTRA_FRACTION_LEN) |
-          (static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
-
-      if constexpr (EXTRA_FRACTION_LEN > 0) {
-        if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
-          switch (quick_get_round()) {
-          case FE_TONEAREST:
-          case FE_UPWARD:
-            return OutFPBits::inf().get_val();
-          default:
-            return OutFPBits::max_normal().get_val();
-          }
-        }
-
-        if (x_exp <
-            -OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
-          switch (quick_get_round()) {
-          case FE_UPWARD:
-            return OutFPBits::min_subnormal().get_val();
-          default:
-            return OutType(0.0);
-          }
-        }
-
-        if (x_exp <= 0) {
-          int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
-          InStorageType underflow_extra_fraction_mask =
-              (InStorageType(1) << underflow_extra_fraction_len) - 1;
-
-          rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
-               0;
-          OutStorageType subnormal_mant =
-              static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
-          lsb = (subnormal_mant & 1) != 0;
-          sticky = sticky || (y & underflow_extra_fraction_mask) != 0;
-
-          switch (quick_get_round()) {
-          case FE_TONEAREST:
-            if (rb && (lsb || sticky))
-              ++subnormal_mant;
-            break;
-          case FE_UPWARD:
-            if (rb || sticky)
-              ++subnormal_mant;
-            break;
-          }
-
-          return cpp::bit_cast<OutType>(subnormal_mant);
-        }
-      }
-
-      switch (quick_get_round()) {
-      case FE_TONEAREST:
-        // Round to nearest, ties to even
-        if (rb && (lsb || (r != 0)))
-          ++y_out;
-        break;
-      case FE_UPWARD:
-        if (rb || (r != 0) || sticky)
-          ++y_out;
-        break;
+        // Rounding bit.
+        y |= 1 << 1;
       }
+      // Sticky bit.
+      y |= static_cast<unsigned int>(r != 0);
 
-      return cpp::bit_cast<OutType>(y_out);
+      DyadicFloat yd(Sign::POS, (x_exp >> 1) - 2 - InFPBits::FRACTION_LEN, y);
+      return yd.template as<OutType, /*ShouldSignalExceptions=*/true>();
     }
   }
 }
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index 3dfc4ac94827d..1715cf6275b29 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -106,7 +106,10 @@ add_math_entrypoint_object(f16fmaf)
 add_math_entrypoint_object(f16fmal)
 add_math_entrypoint_object(f16fmaf128)
 
+add_math_entrypoint_object(f16sqrt)
 add_math_entrypoint_object(f16sqrtf)
+add_math_entrypoint_object(f16sqrtl)
+add_math_entrypoint_object(f16sqrtf128)
 
 add_math_entrypoint_object(fabs)
 add_math_entrypoint_object(fabsf)
diff --git a/libc/src/math/f16sqrt.h b/libc/src/math/f16sqrt.h
new file mode 100644
index 0000000000000..f1134ac5ee2b9
--- /dev/null
+++ b/libc/src/math/f16sqrt.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrt -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRT_H
+#define LLVM_LIBC_SRC_MATH_F16SQRT_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrt(double x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRT_H
diff --git a/libc/src/math/f16sqrtf128.h b/libc/src/math/f16sqrtf128.h
new file mode 100644
index 0000000000000..61a6ce9ea5a5d
--- /dev/null
+++ b/libc/src/math/f16sqrtf128.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrtf128 -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF128_H
+#define LLVM_LIBC_SRC_MATH_F16SQRTF128_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrtf128(float128 x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRTF128_H
diff --git a/libc/src/math/f16sqrtl.h b/libc/src/math/f16sqrtl.h
new file mode 100644
index 0000000000000..fd3c55fc95f32
--- /dev/null
+++ b/libc/src/math/f16sqrtl.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrtl ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRTL_H
+#define LLVM_LIBC_SRC_MATH_F16SQRTL_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrtl(long double x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRTL_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index 3773a2b49c416..eb67447d76bc2 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -3841,6 +3841,19 @@ add_entrypoint_object(
     -O3
 )
 
+add_entrypoint_object(
+  f16sqrt
+  SRCS
+    f16sqrt.cpp
+  HDRS
+    ../f16sqrt.h
+  DEPENDS
+    libc.src.__support.macros.properties.types
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
+
 add_entrypoint_object(
   f16sqrtf
   SRCS
@@ -3853,3 +3866,29 @@ add_entrypoint_object(
   COMPILE_OPTIONS
     -O3
 )
+
+add_entrypoint_object(
+  f16sqrtl
+  SRCS
+    f16sqrtl.cpp
+  HDRS
+    ../f16sqrtl.h
+  DEPENDS
+    libc.src.__support.macros.properties.types
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
+
+add_entrypoint_object(
+  f16sqrtf128
+  SRCS
+    f16sqrtf128.cpp
+  HDRS
+    ../f16sqrtf128.h
+  DEPENDS
+    libc.src.__support.macros.properties.types
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
diff --git a/libc/src/math/generic/f16sqrt.cpp b/libc/src/math/generic/f16sqrt.cpp
new file mode 100644
index 0000000000000..9d5f081f63315
--- /dev/null
+++ b/libc/src/math/generic/f16sqrt.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrt function --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrt, (double x)) {
+  return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/f16sqrtf128.cpp b/libc/src/math/generic/f16sqrtf128.cpp
new file mode 100644
index 0000000000000..11a1e8252788e
--- /dev/null
+++ b/libc/src/math/generic/f16sqrtf128.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrtf128 function ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrtf128.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrtf128, (float128 x)) {
+  return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/f16sqrtl.cpp b/libc/src/math/generic/f16sqrtl.cpp
new file mode 100644
index 0000000000000..2aaac9a780f66
--- /dev/null
+++ b/libc/src/math/generic/f16sqrtl.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrtl function -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrtl.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrtl, (long double x)) {
+  return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index 36d2a2fbfd302..0a04624cc1b0f 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -1259,9 +1259,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtf
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -1271,9 +1272,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     sqrt_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrt
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -1283,9 +1285,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtl
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -1295,9 +1298,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     generic_sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtf
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -1310,9 +1314,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     generic_sqrt_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrt
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -1325,9 +1330,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     generic_sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtl
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -1960,6 +1966,45 @@ add_fp_unittest(
     libc.src.stdlib.srand
 )
 
+add_fp_unittest(
+  f16sqrt_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    f16sqrt_test.cpp
+  HDRS
+    SqrtTest.h
+  DEPENDS
+    libc.src.math.f16sqrt
+)
+
+add_fp_unittest(
+  f16sqrtf_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    f16sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
+  DEPENDS
+    libc.src.math.f16sqrtf
+)
+
+add_fp_unittest(
+  f16sqrtl_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    f16sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
+  DEPENDS
+    libc.src.math.f16sqrtl
+)
+
 add_subdirectory(generic)
 add_subdirectory(smoke)
 
diff --git a/libc/test/src/math/SqrtTest.h b/libc/test/src/math/SqrtTest.h
index 1c422e201bb23..770cc94b3b940 100644
--- a/libc/test/src/math/SqrtTest.h
+++ b/libc/test/src/math/SqrtTest.h
@@ -6,51 +6,36 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "src/__support/CPP/bit.h"
 #include "test/UnitTest/FEnvSafeTest.h"
 #include "test/UnitTest/FPMatcher.h"
 #include "test/UnitTest/Test.h"
 #include "utils/MPFRWrapper/MPFRUtils.h"
 
-#include "hdr/math_macros.h"
-
 namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
 
-template <typename T>
+template <typename OutType, typename InType>
 class SqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
 
-  DECLARE_SPECIAL_CONSTANTS(T)
+  DECLARE_SPECIAL_CONSTANTS(InType)
 
   static constexpr StorageType HIDDEN_BIT =
-      StorageType(1) << LIBC_NAMESPACE::fputil::FPBits<T>::FRACTION_LEN;
+      StorageType(1) << LIBC_NAMESPACE::fputil::FPBits<InType>::FRACTION_LEN;
 
 public:
-  typedef T (*SqrtFunc)(T);
-
-  void test_special_numbers(SqrtFunc func) {
-    ASSERT_FP_EQ(aNaN, func(aNaN));
-    ASSERT_FP_EQ(inf, func(inf));
-    ASSERT_FP_EQ(aNaN, func(neg_inf));
-    ASSERT_FP_EQ(0.0, func(0.0));
-    ASSERT_FP_EQ(-0.0, func(-0.0));
-    ASSERT_FP_EQ(aNaN, func(T(-1.0)));
-    ASSERT_FP_EQ(T(1.0), func(T(1.0)));
-    ASSERT_FP_EQ(T(2.0), func(T(4.0)));
-    ASSERT_FP_EQ(T(3.0), func(T(9.0)));
-  }
+  using SqrtFunc = OutType (*)(InType);
 
   void test_denormal_values(SqrtFunc func) {
     for (StorageType mant = 1; mant < HIDDEN_BIT; mant <<= 1) {
-      FPBits denormal(T(0.0));
+      FPBits denormal(zero);
       denormal.set_mantissa(mant);
-      T x = denormal.get_val();
+      InType x = denormal.get_val();
       EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Sqrt, x, func(x), 0.5);
     }
 
     constexpr StorageType COUNT = 200'001;
     constexpr StorageType STEP = HIDDEN_BIT / COUNT;
     for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
-      T x = LIBC_NAMESPACE::cpp::bit_cast<T>(v);
+      InType x = FPBits(i).get_val();
       EXPECT_MPFR_MATCH_ALL_ROUND...
[truncated]

@overmighty overmighty force-pushed the libc-math-f16sqrtf-mpfr-test branch from d23f856 to ef38799 Compare July 1, 2024 10:27
@overmighty
Copy link
Member Author

Rebased.

@lntue lntue self-requested a review July 1, 2024 12:45
@lntue lntue merged commit a3f700a into llvm:main Jul 1, 2024
6 checks passed
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants