Skip to content

[libc][math][c23] Add f16sqrt{,l,f128} C23 math functions #96642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 30, 2024

Conversation

overmighty
Copy link
Member

@overmighty overmighty commented Jun 25, 2024

Part of #95250.

@overmighty
Copy link
Member Author

cc @lntue

@llvmbot llvmbot added the libc label Jun 25, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2024

@llvm/pr-subscribers-libc

Author: OverMighty (overmighty)

Changes

Part of #95250.


Patch is 27.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96642.diff

23 Files Affected:

  • (modified) libc/config/linux/aarch64/entrypoints.txt (+9)
  • (modified) libc/config/linux/x86_64/entrypoints.txt (+9)
  • (modified) libc/docs/math/index.rst (+1-1)
  • (modified) libc/spec/stdc.td (+3)
  • (modified) libc/src/__support/FPUtil/generic/CMakeLists.txt (+1)
  • (modified) libc/src/__support/FPUtil/generic/sqrt.h (+14-88)
  • (modified) libc/src/math/CMakeLists.txt (+3)
  • (added) libc/src/math/f16sqrt.h (+20)
  • (added) libc/src/math/f16sqrtf128.h (+20)
  • (added) libc/src/math/f16sqrtl.h (+20)
  • (modified) libc/src/math/generic/CMakeLists.txt (+39)
  • (added) libc/src/math/generic/f16sqrt.cpp (+19)
  • (added) libc/src/math/generic/f16sqrtf128.cpp (+19)
  • (added) libc/src/math/generic/f16sqrtl.cpp (+19)
  • (modified) libc/test/src/math/CMakeLists.txt (+38-6)
  • (modified) libc/test/src/math/SqrtTest.h (+16-27)
  • (added) libc/test/src/math/f16sqrt_test.cpp (+13)
  • (added) libc/test/src/math/f16sqrtl_test.cpp (+13)
  • (modified) libc/test/src/math/smoke/CMakeLists.txt (+36)
  • (added) libc/test/src/math/smoke/f16sqrt_test.cpp (+13)
  • (added) libc/test/src/math/smoke/f16sqrtf128_test.cpp (+13)
  • (added) libc/test/src/math/smoke/f16sqrtl_test.cpp (+13)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+11)
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index a875a17f06b3e..ce74653091b0d 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -507,7 +507,9 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.copysignf16
     libc.src.math.f16divf
     libc.src.math.f16fmaf
+    libc.src.math.f16sqrt
     libc.src.math.f16sqrtf
+    libc.src.math.f16sqrtl
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
@@ -558,6 +560,13 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.ufromfpf16
     libc.src.math.ufromfpxf16
   )
+
+  if(LIBC_TYPES_HAS_FLOAT128)
+    list(APPEND TARGET_LIBM_ENTRYPOINTS
+      # math.h C23 mixed _Float16 and _Float128 entrypoints
+      libc.src.math.f16sqrtf128
+    )
+  endif()
 endif()
 
 if(LIBC_TYPES_HAS_FLOAT128)
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 34748ff5950ad..62383676ef16a 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -538,7 +538,9 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.copysignf16
     libc.src.math.f16divf
     libc.src.math.f16fmaf
+    libc.src.math.f16sqrt
     libc.src.math.f16sqrtf
+    libc.src.math.f16sqrtl
     libc.src.math.fabsf16
     libc.src.math.fdimf16
     libc.src.math.floorf16
@@ -587,6 +589,13 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.ufromfpf16
     libc.src.math.ufromfpxf16
   )
+
+  if(LIBC_TYPES_HAS_FLOAT128)
+    list(APPEND TARGET_LIBM_ENTRYPOINTS
+      # math.h C23 mixed _Float16 and _Float128 entrypoints
+      libc.src.math.f16sqrtf128
+    )
+  endif()
 endif()
 
 if(LIBC_TYPES_HAS_FLOAT128)
diff --git a/libc/docs/math/index.rst b/libc/docs/math/index.rst
index 95f450ab75960..a96ec59cfcfd9 100644
--- a/libc/docs/math/index.rst
+++ b/libc/docs/math/index.rst
@@ -292,7 +292,7 @@ Higher Math Functions
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fma       | |check|          | |check|         |                        |                      |                        | 7.12.13.1              | F.10.10.1                  |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
-| f16sqrt   | |check|          |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
+| f16sqrt   | |check|          | |check|         | |check|                | N/A                  | |check|                | 7.12.14.6              | F.10.11                    |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
 | fsqrt     | N/A              |                 |                        | N/A                  |                        | 7.12.14.6              | F.10.11                    |
 +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
diff --git a/libc/spec/stdc.td b/libc/spec/stdc.td
index 651f49deef4c1..ba9d7716cf358 100644
--- a/libc/spec/stdc.td
+++ b/libc/spec/stdc.td
@@ -728,7 +728,10 @@ def StdC : StandardSpec<"stdc"> {
 
           GuardedFunctionSpec<"f16divf", RetValSpec<Float16Type>, [ArgSpec<FloatType>, ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
 
+          GuardedFunctionSpec<"f16sqrt", RetValSpec<Float16Type>, [ArgSpec<DoubleType>], "LIBC_TYPES_HAS_FLOAT16">,
           GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
+          GuardedFunctionSpec<"f16sqrtl", RetValSpec<Float16Type>, [ArgSpec<LongDoubleType>], "LIBC_TYPES_HAS_FLOAT16">,
+          GuardedFunctionSpec<"f16sqrtf128", RetValSpec<Float16Type>, [ArgSpec<Float128Type>], "LIBC_TYPES_HAS_FLOAT16_AND_FLOAT128">,
       ]
   >;
 
diff --git a/libc/src/__support/FPUtil/generic/CMakeLists.txt b/libc/src/__support/FPUtil/generic/CMakeLists.txt
index 33b2564bfa087..2c386e1cae098 100644
--- a/libc/src/__support/FPUtil/generic/CMakeLists.txt
+++ b/libc/src/__support/FPUtil/generic/CMakeLists.txt
@@ -8,6 +8,7 @@ add_header_library(
     libc.src.__support.common
     libc.src.__support.CPP.bit
     libc.src.__support.CPP.type_traits
+    libc.src.__support.FPUtil.dyadic_float
     libc.src.__support.FPUtil.fenv_impl
     libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.rounding_mode
diff --git a/libc/src/__support/FPUtil/generic/sqrt.h b/libc/src/__support/FPUtil/generic/sqrt.h
index d6e894fdfe021..e9cd3f47eef27 100644
--- a/libc/src/__support/FPUtil/generic/sqrt.h
+++ b/libc/src/__support/FPUtil/generic/sqrt.h
@@ -14,7 +14,7 @@
 #include "src/__support/CPP/type_traits.h"
 #include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
-#include "src/__support/FPUtil/rounding_mode.h"
+#include "src/__support/FPUtil/dyadic_float.h"
 #include "src/__support/common.h"
 #include "src/__support/uint128.h"
 
@@ -78,16 +78,14 @@ sqrt(InType x) {
     return x86::sqrt(x);
   } else {
     // IEEE floating points formats.
-    using OutFPBits = typename fputil::FPBits<OutType>;
-    using OutStorageType = typename OutFPBits::StorageType;
-    using InFPBits = typename fputil::FPBits<InType>;
+    using OutFPBits = FPBits<OutType>;
+    using InFPBits = FPBits<InType>;
     using InStorageType = typename InFPBits::StorageType;
+    using DyadicFloat =
+        DyadicFloat<cpp::bit_ceil(static_cast<size_t>(InFPBits::STORAGE_LEN))>;
+
     constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
     constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
-    constexpr int EXTRA_FRACTION_LEN =
-        InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
-    constexpr InStorageType EXTRA_FRACTION_MASK =
-        (InStorageType(1) << EXTRA_FRACTION_LEN) - 1;
 
     InFPBits bits(x);
 
@@ -146,91 +144,19 @@ sqrt(InType x) {
       }
 
       // We compute one more iteration in order to round correctly.
-      bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
-                 0;    // Least significant bit
-      bool rb = false; // Round bit
       r <<= 2;
-      InStorageType tmp = (y << 2) + 1;
+      y <<= 2;
+      InStorageType tmp = y + 1;
       if (r >= tmp) {
         r -= tmp;
-        rb = true;
-      }
-
-      bool sticky = false;
-
-      if constexpr (EXTRA_FRACTION_LEN > 0) {
-        sticky = rb || (y & EXTRA_FRACTION_MASK) != 0;
-        rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
-      }
-
-      // Remove hidden bit and append the exponent field.
-      x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);
-
-      OutStorageType y_out = static_cast<OutStorageType>(
-          ((y - ONE) >> EXTRA_FRACTION_LEN) |
-          (static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
-
-      if constexpr (EXTRA_FRACTION_LEN > 0) {
-        if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
-          switch (quick_get_round()) {
-          case FE_TONEAREST:
-          case FE_UPWARD:
-            return OutFPBits::inf().get_val();
-          default:
-            return OutFPBits::max_normal().get_val();
-          }
-        }
-
-        if (x_exp <
-            -OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
-          switch (quick_get_round()) {
-          case FE_UPWARD:
-            return OutFPBits::min_subnormal().get_val();
-          default:
-            return OutType(0.0);
-          }
-        }
-
-        if (x_exp <= 0) {
-          int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
-          InStorageType underflow_extra_fraction_mask =
-              (InStorageType(1) << underflow_extra_fraction_len) - 1;
-
-          rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
-               0;
-          OutStorageType subnormal_mant =
-              static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
-          lsb = (subnormal_mant & 1) != 0;
-          sticky = sticky || (y & underflow_extra_fraction_mask) != 0;
-
-          switch (quick_get_round()) {
-          case FE_TONEAREST:
-            if (rb && (lsb || sticky))
-              ++subnormal_mant;
-            break;
-          case FE_UPWARD:
-            if (rb || sticky)
-              ++subnormal_mant;
-            break;
-          }
-
-          return cpp::bit_cast<OutType>(subnormal_mant);
-        }
-      }
-
-      switch (quick_get_round()) {
-      case FE_TONEAREST:
-        // Round to nearest, ties to even
-        if (rb && (lsb || (r != 0)))
-          ++y_out;
-        break;
-      case FE_UPWARD:
-        if (rb || (r != 0) || sticky)
-          ++y_out;
-        break;
+        // Rounding bit.
+        y |= 1 << 1;
       }
+      // Sticky bit.
+      y |= r != 0;
 
-      return cpp::bit_cast<OutType>(y_out);
+      DyadicFloat yd(Sign::POS, (x_exp >> 1) - 2 - InFPBits::FRACTION_LEN, y);
+      return yd.template as<OutType, /*ShouldSignalExceptions=*/true>();
     }
   }
 }
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index 711cbf8bbfdca..d914e642fd609 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -103,7 +103,10 @@ add_math_entrypoint_object(f16divf)
 
 add_math_entrypoint_object(f16fmaf)
 
+add_math_entrypoint_object(f16sqrt)
 add_math_entrypoint_object(f16sqrtf)
+add_math_entrypoint_object(f16sqrtl)
+add_math_entrypoint_object(f16sqrtf128)
 
 add_math_entrypoint_object(fabs)
 add_math_entrypoint_object(fabsf)
diff --git a/libc/src/math/f16sqrt.h b/libc/src/math/f16sqrt.h
new file mode 100644
index 0000000000000..f1134ac5ee2b9
--- /dev/null
+++ b/libc/src/math/f16sqrt.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrt -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRT_H
+#define LLVM_LIBC_SRC_MATH_F16SQRT_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrt(double x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRT_H
diff --git a/libc/src/math/f16sqrtf128.h b/libc/src/math/f16sqrtf128.h
new file mode 100644
index 0000000000000..61a6ce9ea5a5d
--- /dev/null
+++ b/libc/src/math/f16sqrtf128.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrtf128 -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF128_H
+#define LLVM_LIBC_SRC_MATH_F16SQRTF128_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrtf128(float128 x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRTF128_H
diff --git a/libc/src/math/f16sqrtl.h b/libc/src/math/f16sqrtl.h
new file mode 100644
index 0000000000000..fd3c55fc95f32
--- /dev/null
+++ b/libc/src/math/f16sqrtl.h
@@ -0,0 +1,20 @@
+//===-- Implementation header for f16sqrtl ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_F16SQRTL_H
+#define LLVM_LIBC_SRC_MATH_F16SQRTL_H
+
+#include "src/__support/macros/properties/types.h"
+
+namespace LIBC_NAMESPACE {
+
+float16 f16sqrtl(long double x);
+
+} // namespace LIBC_NAMESPACE
+
+#endif // LLVM_LIBC_SRC_MATH_F16SQRTL_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index fc2024c89b5df..2ea3ed0dc4d74 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -3757,6 +3757,19 @@ add_entrypoint_object(
     -O3
 )
 
+add_entrypoint_object(
+  f16sqrt
+  SRCS
+    f16sqrt.cpp
+  HDRS
+    ../f16sqrt.h
+  DEPENDS
+    libc.src.__support.macros.properties.types
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
+
 add_entrypoint_object(
   f16sqrtf
   SRCS
@@ -3769,3 +3782,29 @@ add_entrypoint_object(
   COMPILE_OPTIONS
     -O3
 )
+
+add_entrypoint_object(
+  f16sqrtl
+  SRCS
+    f16sqrtl.cpp
+  HDRS
+    ../f16sqrtl.h
+  DEPENDS
+    libc.src.__support.macros.properties.types
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
+
+add_entrypoint_object(
+  f16sqrtf128
+  SRCS
+    f16sqrtf128.cpp
+  HDRS
+    ../f16sqrtf128.h
+  DEPENDS
+    libc.src.__support.macros.properties.types
+    libc.src.__support.FPUtil.sqrt
+  COMPILE_OPTIONS
+    -O3
+)
diff --git a/libc/src/math/generic/f16sqrt.cpp b/libc/src/math/generic/f16sqrt.cpp
new file mode 100644
index 0000000000000..9d5f081f63315
--- /dev/null
+++ b/libc/src/math/generic/f16sqrt.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrt function --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrt.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrt, (double x)) {
+  return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/f16sqrtf128.cpp b/libc/src/math/generic/f16sqrtf128.cpp
new file mode 100644
index 0000000000000..11a1e8252788e
--- /dev/null
+++ b/libc/src/math/generic/f16sqrtf128.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrtf128 function ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrtf128.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrtf128, (float128 x)) {
+  return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/math/generic/f16sqrtl.cpp b/libc/src/math/generic/f16sqrtl.cpp
new file mode 100644
index 0000000000000..2aaac9a780f66
--- /dev/null
+++ b/libc/src/math/generic/f16sqrtl.cpp
@@ -0,0 +1,19 @@
+//===-- Implementation of f16sqrtl function -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/f16sqrtl.h"
+#include "src/__support/FPUtil/sqrt.h"
+#include "src/__support/common.h"
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(float16, f16sqrtl, (long double x)) {
+  return fputil::sqrt<float16>(x);
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index ba588662f469e..5d0ae871363ff 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -1247,9 +1247,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtf
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -1259,9 +1260,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     sqrt_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrt
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -1271,9 +1273,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtl
-    libc.src.__support.FPUtil.fp_bits
 )
 
 add_fp_unittest(
@@ -1283,9 +1286,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     generic_sqrtf_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtf
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -1298,9 +1302,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     generic_sqrt_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrt
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -1313,9 +1318,10 @@ add_fp_unittest(
     libc-math-unittests
   SRCS
     generic_sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
   DEPENDS
     libc.src.math.sqrtl
-    libc.src.__support.FPUtil.fp_bits
     libc.src.__support.FPUtil.generic.sqrt
   COMPILE_OPTIONS
     -O3
@@ -1903,6 +1909,32 @@ add_fp_unittest(
     libc.src.math.f16divf
 )
 
+add_fp_unittest(
+  f16sqrt_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    f16sqrt_test.cpp
+  HDRS
+    SqrtTest.h
+  DEPENDS
+    libc.src.math.f16sqrt
+)
+
+add_fp_unittest(
+  f16sqrtl_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    f16sqrtl_test.cpp
+  HDRS
+    SqrtTest.h
+  DEPENDS
+    libc.src.math.f16sqrtl
+)
+
 add_fp_unittest(
   f16fmaf_test
   NEED_MPFR
diff --git a/libc/test/src/math/SqrtTest.h b/libc/test/src/math/SqrtTest.h
index 1c422e201bb23..770cc94b3b940 100644
--- a/libc/test/src/math/SqrtTest.h
+++ b/libc/test/src/math/SqrtTest.h
@@ -6,51 +6,36 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "src/__support/CPP/bit.h"
 #include "test/UnitTest/FEnvSafeTest.h"
 #include "test/UnitTest/FPMatcher.h"
 #include "test/UnitTest/Test.h"
 #include "utils/MPFRWrapper/MPFRUtils.h"
 
-#include "hdr/math_macros.h"
-
 namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
 
-template <typename T>
+template <typename OutType, typename InType>
 class SqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
 
-  DECLARE_SPECIAL_CONSTANTS(T)
+  DECLARE_SPECIAL_CONSTANTS(InType)
 
   static constexpr StorageType HIDDEN_BIT =
-      StorageType(1) << LIBC_NAMESPACE::fputil::FPBits<T>::FRACTION_LEN;
+      StorageType(1) << LIBC_NAMESPACE::fputil::FPBits<InType>::FRACTION_LEN;
 
 public:
-  typedef T (*SqrtFunc)(T);
-
-  void test_special_numbers(SqrtFunc func) {
-    ASSERT_FP_EQ(aNaN, func(aNaN));
-    ASSERT_FP_EQ(inf, func(inf));
-    ASSERT_FP_EQ(aNaN, func(neg_inf));
-    ASSERT_FP_EQ(0.0, func(0.0));
-    ASSERT_FP_EQ(-0.0, func(-0.0));
-    ASSERT_FP_EQ(aNaN, func(T(-1.0)));
-    ASSERT_FP_EQ(T(1.0), func(T(1.0)));
-    ASSERT_FP_EQ(T(2.0), func(T(4.0)));
-    ASSERT_FP_EQ(T(3.0), func(T(9.0)));
-  }
+  using SqrtFunc = OutType (*)(InType);
 
   void test_denormal_values(SqrtFunc func) {
     for (StorageType mant = 1; mant < HIDDEN_BIT; mant <<= 1) {
-      FPBits denormal(T(0.0));
+      FPBits denormal(zero);
       denormal.set_mantissa(mant);
-      T x = denormal.get_val();
+      InType x = denormal.get_val();
       EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Sqrt, x, func(x), 0.5);
     }
 
     constexpr StorageType COUNT = 200'001;
     constexpr StorageType STEP = HIDDEN_BIT / COUNT;
     for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
-      T x = LIBC_NAMESPACE::cpp::bit_cast<T>(v);
+      InType x = FPBits(i).get_val();
       EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Sqrt, x, func(x), 0.5);
     }
   }
@@ -59,17 +44,21 @@ class SqrtTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
     con...
[truncated]

@overmighty
Copy link
Member Author

Should probably change fputil::sqrt to handle NaNs correctly.

@lntue lntue self-requested a review June 27, 2024 14:17
@overmighty overmighty force-pushed the libc-math-f16sqrt-variants branch from ecf4315 to 9dc097f Compare June 28, 2024 13:49
@overmighty
Copy link
Member Author

Rebased onto main to catch errors like in #96976 and #97039.

@overmighty overmighty force-pushed the libc-math-f16sqrt-variants branch from 6e21b9d to ccacf01 Compare June 29, 2024 23:14
@overmighty
Copy link
Member Author

Rebased to fix a merge conflict and added these 2 commits:

  • [libc][math][c23] Move f16sqrt{,f,l} specs to llvm_libc_ext.td
  • [libc][math][c23] Disable f16sqrt{l,f128} on AArch64 Linux

@lntue lntue merged commit 6c1c451 into llvm:main Jun 30, 2024
7 checks passed
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants