Skip to content

Commit 263be9f

Browse files
author
Job Henandez Lara
authored
[libc][math][c23] fmul correcly rounded to all rounding modes (#91537)
This is an implementation of floating point multiplication: It will consist of - `double x double -> float`
1 parent 1e92ad4 commit 263be9f

File tree

14 files changed

+303
-1
lines changed

14 files changed

+303
-1
lines changed

libc/config/linux/aarch64/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ set(TARGET_LIBM_ENTRYPOINTS
394394
libc.src.math.fminimum_mag_num
395395
libc.src.math.fminimum_mag_numf
396396
libc.src.math.fminimum_mag_numl
397+
libc.src.math.fmul
397398
libc.src.math.fmod
398399
libc.src.math.fmodf
399400
libc.src.math.fmodl

libc/config/linux/arm/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ set(TARGET_LIBM_ENTRYPOINTS
261261
libc.src.math.fminimum_mag_num
262262
libc.src.math.fminimum_mag_numf
263263
libc.src.math.fminimum_mag_numl
264+
libc.src.math.fmul
264265
libc.src.math.fmod
265266
libc.src.math.fmodf
266267
libc.src.math.frexp

libc/config/linux/riscv/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ set(TARGET_LIBM_ENTRYPOINTS
402402
libc.src.math.fminimum_mag_num
403403
libc.src.math.fminimum_mag_numf
404404
libc.src.math.fminimum_mag_numl
405+
libc.src.math.fmul
405406
libc.src.math.fmod
406407
libc.src.math.fmodf
407408
libc.src.math.fmodl

libc/config/linux/x86_64/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ set(TARGET_LIBM_ENTRYPOINTS
421421
libc.src.math.fminimum_mag_num
422422
libc.src.math.fminimum_mag_numf
423423
libc.src.math.fminimum_mag_numl
424+
libc.src.math.fmul
424425
libc.src.math.fmod
425426
libc.src.math.fmodf
426427
libc.src.math.fmodl

libc/config/windows/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ set(TARGET_LIBM_ENTRYPOINTS
180180
libc.src.math.fminimum_mag_num
181181
libc.src.math.fminimum_mag_numf
182182
libc.src.math.fminimum_mag_numl
183+
libc.src.math.fmul
183184
libc.src.math.fmod
184185
libc.src.math.fmodf
185186
libc.src.math.fmodl

libc/docs/math/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ Basic Operations
158158
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
159159
| fmod | |check| | |check| | |check| | |check| | |check| | 7.12.10.1 | F.10.7.1 |
160160
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
161-
| fmul | N/A | | | N/A | | 7.12.14.3 | F.10.11 |
161+
| fmul | N/A | |check| | | N/A | | 7.12.14.3 | F.10.11 |
162162
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
163163
| frexp | |check| | |check| | |check| | | |check| | 7.12.6.7 | F.10.3.7 |
164164
+------------------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+

libc/spec/stdc.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ def StdC : StandardSpec<"stdc"> {
472472
GuardedFunctionSpec<"fminimum_mag_numf16", RetValSpec<Float16Type>, [ArgSpec<Float16Type>, ArgSpec<Float16Type>], "LIBC_TYPES_HAS_FLOAT16">,
473473
GuardedFunctionSpec<"fminimum_mag_numf128", RetValSpec<Float128Type>, [ArgSpec<Float128Type>, ArgSpec<Float128Type>], "LIBC_TYPES_HAS_FLOAT128">,
474474

475+
FunctionSpec<"fmul", RetValSpec<FloatType>, [ArgSpec<DoubleType>, ArgSpec<DoubleType>]>,
476+
477+
475478
FunctionSpec<"fma", RetValSpec<DoubleType>, [ArgSpec<DoubleType>, ArgSpec<DoubleType>, ArgSpec<DoubleType>]>,
476479
FunctionSpec<"fmaf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>, ArgSpec<FloatType>]>,
477480

libc/src/math/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ add_math_entrypoint_object(fminimum_mag_numl)
180180
add_math_entrypoint_object(fminimum_mag_numf16)
181181
add_math_entrypoint_object(fminimum_mag_numf128)
182182

183+
add_math_entrypoint_object(fmul)
184+
183185
add_math_entrypoint_object(fmod)
184186
add_math_entrypoint_object(fmodf)
185187
add_math_entrypoint_object(fmodl)

libc/src/math/fmul.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===-- Implementation header for fmul --------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_SRC_MATH_FMUL_H
10+
#define LLVM_LIBC_SRC_MATH_FMUL_H
11+
12+
namespace LIBC_NAMESPACE {
13+
14+
float fmul(double x, double y);
15+
16+
} // namespace LIBC_NAMESPACE
17+
18+
#endif // LLVM_LIBC_SRC_MATH_FMUL_H

libc/src/math/generic/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2354,6 +2354,22 @@ add_entrypoint_object(
23542354
-O3
23552355
)
23562356

2357+
add_entrypoint_object(
2358+
fmul
2359+
SRCS
2360+
fmul.cpp
2361+
HDRS
2362+
../fmul.h
2363+
DEPENDS
2364+
libc.src.__support.FPUtil.basic_operations
2365+
libc.src.__support.uint128
2366+
libc.src.__support.CPP.bit
2367+
libc.src.__support.FPUtil.fp_bits
2368+
libc.src.__support.FPUtil.rounding_mode
2369+
COMPILE_OPTIONS
2370+
-O3
2371+
)
2372+
23572373
add_entrypoint_object(
23582374
sqrt
23592375
SRCS

libc/src/math/generic/fmul.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
//===-- Implementation of fmul function------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/math/fmul.h"
10+
#include "src/__support/CPP/bit.h"
11+
#include "src/__support/FPUtil/BasicOperations.h"
12+
#include "src/__support/FPUtil/FPBits.h"
13+
#include "src/__support/FPUtil/rounding_mode.h"
14+
#include "src/__support/common.h"
15+
#include "src/__support/uint128.h"
16+
17+
namespace LIBC_NAMESPACE {
18+
19+
LLVM_LIBC_FUNCTION(float, fmul, (double x, double y)) {
20+
auto x_bits = fputil::FPBits<double>(x);
21+
22+
auto y_bits = fputil::FPBits<double>(y);
23+
24+
auto output_sign = (x_bits.sign() != y_bits.sign()) ? Sign::NEG : Sign::POS;
25+
26+
if (LIBC_UNLIKELY(x_bits.is_inf_or_nan() || y_bits.is_inf_or_nan() ||
27+
x_bits.is_zero() || y_bits.is_zero())) {
28+
if (x_bits.is_nan())
29+
return static_cast<float>(x);
30+
if (y_bits.is_nan())
31+
return static_cast<float>(y);
32+
if (x_bits.is_inf())
33+
return y_bits.is_zero()
34+
? fputil::FPBits<float>::quiet_nan().get_val()
35+
: fputil::FPBits<float>::inf(output_sign).get_val();
36+
if (y_bits.is_inf())
37+
return x_bits.is_zero()
38+
? fputil::FPBits<float>::quiet_nan().get_val()
39+
: fputil::FPBits<float>::inf(output_sign).get_val();
40+
// Now either x or y is zero, and the other one is finite.
41+
return fputil::FPBits<float>::zero(output_sign).get_val();
42+
}
43+
44+
uint64_t mx, my;
45+
46+
// Get mantissa and append the hidden bit if needed.
47+
mx = x_bits.get_explicit_mantissa();
48+
my = y_bits.get_explicit_mantissa();
49+
50+
// Get the corresponding biased exponent.
51+
int ex = x_bits.get_explicit_exponent();
52+
int ey = y_bits.get_explicit_exponent();
53+
54+
// Count the number of leading zeros of the explicit mantissas.
55+
int nx = cpp::countl_zero(mx);
56+
int ny = cpp::countl_zero(my);
57+
// Shift the leading 1 bit to the most significant bit.
58+
mx <<= nx;
59+
my <<= ny;
60+
61+
// Adjust exponent accordingly: If x or y are normal, we will only need to
62+
// shift by (exponent length + sign bit = 11 bits. If x or y are denormal, we
63+
// will need to shift more than 11 bits.
64+
ex -= (nx - 11);
65+
ey -= (ny - 11);
66+
67+
UInt128 product = static_cast<UInt128>(mx) * static_cast<UInt128>(my);
68+
int32_t dm1;
69+
uint64_t highs, lows;
70+
uint64_t g, hight, lowt;
71+
uint32_t m;
72+
uint32_t b;
73+
int c;
74+
75+
highs = static_cast<uint64_t>(product >> 64);
76+
c = static_cast<int>(highs >= 0x8000000000000000);
77+
lows = static_cast<uint64_t>(product);
78+
79+
lowt = (lows != 0);
80+
81+
dm1 = ex + ey + c + fputil::FPBits<float>::EXP_BIAS;
82+
83+
int round_mode = fputil::quick_get_round();
84+
if (dm1 >= 255) {
85+
if ((round_mode == FE_TOWARDZERO) ||
86+
(round_mode == FE_UPWARD && output_sign.is_neg()) ||
87+
(round_mode == FE_DOWNWARD && output_sign.is_pos())) {
88+
return fputil::FPBits<float>::max_normal(output_sign).get_val();
89+
}
90+
return fputil::FPBits<float>::inf().get_val();
91+
} else if (dm1 <= 0) {
92+
93+
int m_shift = 40 + c - dm1;
94+
int g_shift = m_shift - 1;
95+
int h_shift = 64 - g_shift;
96+
m = (m_shift >= 64) ? 0 : static_cast<uint32_t>(highs >> m_shift);
97+
98+
g = g_shift >= 64 ? 0 : (highs >> g_shift) & 1;
99+
hight = h_shift >= 64 ? highs : (highs << h_shift) != 0;
100+
101+
dm1 = 0;
102+
} else {
103+
m = static_cast<uint32_t>(highs >> (39 + c));
104+
g = (highs >> (38 + c)) & 1;
105+
hight = (highs << (26 - c)) != 0;
106+
}
107+
108+
if (round_mode == FE_TONEAREST) {
109+
b = g && ((hight && lowt) || ((m & 1) != 0));
110+
} else if ((output_sign.is_neg() && round_mode == FE_DOWNWARD) ||
111+
(output_sign.is_pos() && round_mode == FE_UPWARD)) {
112+
b = (g == 0 && (hight && lowt) == 0) ? 0 : 1;
113+
} else {
114+
b = 0;
115+
}
116+
117+
uint32_t exp16 = (dm1 << 23);
118+
119+
uint32_t m2 = m & fputil::FPBits<float>::FRACTION_MASK;
120+
121+
uint32_t result = (exp16 + m2) + b;
122+
123+
auto result_bits = fputil::FPBits<float>(result);
124+
result_bits.set_sign(output_sign);
125+
return result_bits.get_val();
126+
}
127+
128+
} // namespace LIBC_NAMESPACE

libc/test/src/math/smoke/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,6 +2400,19 @@ add_fp_unittest(
24002400
libc.src.__support.FPUtil.fp_bits
24012401
)
24022402

2403+
add_fp_unittest(
2404+
fmul_test
2405+
SUITE
2406+
libc-math-smoke-tests
2407+
SRCS
2408+
fmul_test.cpp
2409+
HDRS
2410+
FMulTest.h
2411+
DEPENDS
2412+
libc.src.math.fmul
2413+
libc.src.__support.FPUtil.fp_bits
2414+
)
2415+
24032416
add_fp_unittest(
24042417
sqrtf_test
24052418
SUITE

libc/test/src/math/smoke/FMulTest.h

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//===-- Utility class to test fmul[f|l] ---------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_TEST_SRC_MATH_SMOKE_FMULTEST_H
10+
#define LLVM_LIBC_TEST_SRC_MATH_SMOKE_FMULTEST_H
11+
12+
#include "test/UnitTest/FEnvSafeTest.h"
13+
#include "test/UnitTest/FPMatcher.h"
14+
#include "test/UnitTest/Test.h"
15+
16+
template <typename T, typename R>
17+
class FmulTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
18+
19+
DECLARE_SPECIAL_CONSTANTS(T)
20+
21+
public:
22+
typedef T (*FMulFunc)(R, R);
23+
24+
void testMul(FMulFunc func) {
25+
26+
EXPECT_FP_EQ_ALL_ROUNDING(T(15.0), func(3.0, 5.0));
27+
EXPECT_FP_EQ_ALL_ROUNDING(T(0x1.0p-130), func(0x1.0p1, 0x1.0p-131));
28+
EXPECT_FP_EQ_ALL_ROUNDING(T(0x1.0p-127), func(0x1.0p2, 0x1.0p-129));
29+
EXPECT_FP_EQ_ALL_ROUNDING(T(1.0), func(1.0, 1.0));
30+
31+
EXPECT_FP_EQ_ALL_ROUNDING(T(0.0), func(-0.0, -0.0));
32+
EXPECT_FP_EQ_ALL_ROUNDING(T(-0.0), func(0.0, -0.0));
33+
EXPECT_FP_EQ_ALL_ROUNDING(T(-0.0), func(-0.0, 0.0));
34+
35+
EXPECT_FP_EQ_ROUNDING_NEAREST(inf, func(0x1.0p100, 0x1.0p100));
36+
EXPECT_FP_EQ_ROUNDING_UPWARD(inf, func(0x1.0p100, 0x1.0p100));
37+
EXPECT_FP_EQ_ROUNDING_DOWNWARD(max_normal, func(0x1.0p100, 0x1.0p100));
38+
EXPECT_FP_EQ_ROUNDING_TOWARD_ZERO(max_normal, func(0x1.0p100, 0x1.0p100));
39+
40+
EXPECT_FP_EQ_ROUNDING_NEAREST(
41+
0x1p0, func(1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
42+
EXPECT_FP_EQ_ROUNDING_DOWNWARD(
43+
0x1p0, func(1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
44+
EXPECT_FP_EQ_ROUNDING_TOWARD_ZERO(
45+
0x1p0, func(1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
46+
EXPECT_FP_EQ_ROUNDING_UPWARD(
47+
0x1p0, func(1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
48+
49+
EXPECT_FP_EQ_ROUNDING_NEAREST(
50+
0x1.0p-128f + 0x1.0p-148f,
51+
func(1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
52+
EXPECT_FP_EQ_ROUNDING_UPWARD(
53+
0x1.0p-128f + 0x1.0p-148f,
54+
func(1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
55+
EXPECT_FP_EQ_ROUNDING_DOWNWARD(
56+
0x1.0p-128f + 0x1.0p-149f,
57+
func(1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
58+
EXPECT_FP_EQ_ROUNDING_TOWARD_ZERO(
59+
0x1.0p-128f + 0x1.0p-149f,
60+
func(1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150));
61+
}
62+
63+
void testSpecialInputs(FMulFunc func) {
64+
EXPECT_FP_EQ_ALL_ROUNDING(inf, func(inf, 0x1.0p-129));
65+
EXPECT_FP_EQ_ALL_ROUNDING(inf, func(0x1.0p-129, inf));
66+
EXPECT_FP_EQ_ALL_ROUNDING(inf, func(inf, 2.0));
67+
EXPECT_FP_EQ_ALL_ROUNDING(inf, func(3.0, inf));
68+
EXPECT_FP_EQ_ALL_ROUNDING(0.0, func(0.0, 0.0));
69+
70+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(neg_inf, aNaN));
71+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(aNaN, neg_inf));
72+
EXPECT_FP_EQ_ALL_ROUNDING(inf, func(neg_inf, neg_inf));
73+
74+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(0.0, neg_inf));
75+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(neg_inf, 0.0));
76+
77+
EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, func(neg_inf, 1.0));
78+
EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, func(1.0, neg_inf));
79+
80+
EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, func(neg_inf, 0x1.0p-129));
81+
EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, func(0x1.0p-129, neg_inf));
82+
83+
EXPECT_FP_EQ_ALL_ROUNDING(0.0, func(0.0, 0x1.0p-129));
84+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(inf, 0.0));
85+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(0.0, inf));
86+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(0.0, aNaN));
87+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(2.0, aNaN));
88+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(0x1.0p-129, aNaN));
89+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(inf, aNaN));
90+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(aNaN, aNaN));
91+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(0.0, sNaN));
92+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(2.0, sNaN));
93+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(0x1.0p-129, sNaN));
94+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(inf, sNaN));
95+
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, func(sNaN, sNaN));
96+
}
97+
};
98+
99+
#define LIST_FMUL_TESTS(T, R, func) \
100+
using LlvmLibcFmulTest = FmulTest<T, R>; \
101+
TEST_F(LlvmLibcFmulTest, Mul) { testMul(&func); } \
102+
TEST_F(LlvmLibcFmulTest, NaNInf) { testSpecialInputs(&func); }
103+
104+
#endif // LLVM_LIBC_TEST_SRC_MATH_SMOKE_FMULTEST_H
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//===-- Unittests for fmul-------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
9+
#include "FMulTest.h"
10+
11+
#include "src/math/fmul.h"
12+
13+
LIST_FMUL_TESTS(float, double, LIBC_NAMESPACE::fmul)

0 commit comments

Comments
 (0)