Skip to content

Commit 58201ac

Browse files
committed
fix: emit warning for inexact result of floating point binary arithmetic operations
1 parent d22034e commit 58201ac

File tree

2 files changed

+32
-34
lines changed

2 files changed

+32
-34
lines changed

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <cassert>
22
#include <cstdint>
3+
#include <iostream>
34
#include <llvm/ADT/APFloat.h>
45
#include <llvm/ADT/APInt.h>
56
#include <llvm/ADT/APSInt.h>
@@ -102,18 +103,12 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
102103
} else {
103104
llvm::llvm_unreachable_internal(
104105
"encountered an unsupported unary operator");
105-
return failure();
106106
}
107107
return success();
108108
}
109109

110110
if (auto operandFloatAttr = dyn_cast_or_null<FloatAttr>(operandAttr)) {
111-
// auto floatType = operandFloatAttr.getType();
112-
113111
if constexpr (T == UnaryOpKind::exp2) {
114-
// auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
115-
// auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());
116-
117112
auto type = operandFloatAttr.getType();
118113

119114
return TypeSwitch<Type, LogicalResult>(type)
@@ -166,9 +161,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
166161

167162
if (auto lhsIntAttr = dyn_cast_or_null<IntegerAttr>(lhsAttr)) {
168163
auto rhsIntAttr = dyn_cast_or_null<IntegerAttr>(rhsAttr);
169-
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType()) {
164+
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType())
170165
return failure();
171-
}
172166

173167
auto integerType = lhsIntAttr.getType();
174168
APInt resultAPInt;
@@ -211,7 +205,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
211205
resultAPInt = lhsIntAttr.getValue().srem(rhsIntAttr.getValue());
212206
}
213207
} else {
214-
assert(false && "Unsupported binary operator");
208+
llvm::llvm_unreachable_internal(
209+
"encounter an unsupported binary operator.");
215210
}
216211

217212
if (isOverflow)
@@ -223,9 +218,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
223218

224219
if (auto lhsFloatAttr = dyn_cast_or_null<FloatAttr>(lhsAttr)) {
225220
auto rhsFloatAttr = dyn_cast_or_null<FloatAttr>(rhsAttr);
226-
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType()) {
221+
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType())
227222
return failure();
228-
}
229223

230224
APFloat lhsVal = lhsFloatAttr.getValue();
231225
APFloat rhsVal = rhsFloatAttr.getValue();
@@ -248,13 +242,19 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
248242
} else if constexpr (T == BinaryOpKind::mod) {
249243
operationStatus = resultVal.mod(rhsVal);
250244
} else {
251-
assert(false && "Unsupported binary operator");
245+
llvm::llvm_unreachable_internal(
246+
"encounter an unsupported binary operator.");
252247
}
253248

254249
if (operationStatus != APFloat::opOK) {
255-
return failure();
256-
}
250+
if (operationStatus != APFloat::opInexact)
251+
return failure();
257252

253+
emitWarning(rewriter.getUnknownLoc())
254+
<< "Binary arithmetic operation between " << lhsVal.convertToFloat()
255+
<< " and " << rhsVal.convertToFloat()
256+
<< " produced an inexact result";
257+
}
258258
results.push_back(rewriter.getFloatAttr(floatType, resultVal));
259259
return success();
260260
}

mlir/unittests/Dialect/PDL/BuiltinTest.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,19 @@ TEST_F(BuiltinTest, div) {
253253
"Divide by zero?");
254254
}
255255

256-
auto smallF16 = rewriter.getF16FloatAttr(0.0001);
256+
auto BF16Type = rewriter.getBF16Type();
257+
auto oneBF16 = rewriter.getFloatAttr(BF16Type, 1.0);
258+
auto nineBF16 = rewriter.getFloatAttr(BF16Type, 9.0);
259+
260+
// float: inexact result
261+
// return success(), but warning is emitted.
262+
{
263+
TestPDLResultList results(1);
264+
EXPECT_TRUE(
265+
builtin::div(rewriter, results, {oneBF16, nineBF16}).succeeded());
266+
}
267+
257268
auto twoF16 = rewriter.getF16FloatAttr(2.0);
258-
auto maxValF16 = rewriter.getF16FloatAttr(
259-
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
260269
auto zeroF16 = rewriter.getF16FloatAttr(0.0);
261270
auto negzeroF16 = rewriter.getF16FloatAttr(-0.0);
262271

@@ -272,13 +281,6 @@ TEST_F(BuiltinTest, div) {
272281
EXPECT_TRUE(builtin::div(rewriter, results, {twoF16, negzeroF16}).failed());
273282
}
274283

275-
// float: overflow
276-
{
277-
TestPDLResultList results(1);
278-
EXPECT_TRUE(
279-
builtin::div(rewriter, results, {maxValF16, smallF16}).failed());
280-
}
281-
282284
// float: correctness
283285
{
284286
TestPDLResultList results(1);
@@ -456,19 +458,17 @@ TEST_F(BuiltinTest, add) {
456458
EXPECT_TRUE(builtin::add(rewriter, results, {oneI16, oneI32}).failed());
457459
}
458460

459-
auto oneF16 = rewriter.getF16FloatAttr(1.0);
460461
auto oneF32 = rewriter.getF32FloatAttr(1.0);
461462
auto zeroF32 = rewriter.getF32FloatAttr(0.0);
462463
auto negzeroF32 = rewriter.getF32FloatAttr(-0.0);
463464
auto zeroF64 = rewriter.getF64FloatAttr(0.0);
464-
465-
auto maxValF16 = rewriter.getF16FloatAttr(
466-
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
465+
auto overflowF16 = rewriter.getF16FloatAttr(32768);
467466

468467
// float: overflow
469468
{
470469
TestPDLResultList results(1);
471-
EXPECT_TRUE(builtin::add(rewriter, results, {oneF16, maxValF16}).failed());
470+
EXPECT_TRUE(
471+
builtin::add(rewriter, results, {overflowF16, overflowF16}).failed());
472472
}
473473

474474
// float: correctness
@@ -553,19 +553,17 @@ TEST_F(BuiltinTest, sub) {
553553
EXPECT_TRUE(builtin::sub(rewriter, results, {oneI16, oneI32}).failed());
554554
}
555555

556-
auto oneF16 = rewriter.getF16FloatAttr(1.0);
556+
auto oneF16 = rewriter.getF16FloatAttr(100.0);
557557
auto oneF32 = rewriter.getF32FloatAttr(1.0);
558558
auto zeroF32 = rewriter.getF32FloatAttr(0.0);
559559
auto negzeroF32 = rewriter.getF32FloatAttr(-0.0);
560560
auto zeroF64 = rewriter.getF64FloatAttr(0.0);
561-
562-
auto maxValF16 = rewriter.getF16FloatAttr(
563-
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
561+
auto minValF16 = rewriter.getF16FloatAttr(-65504);
564562

565563
// float: overflow
566564
{
567565
TestPDLResultList results(1);
568-
EXPECT_TRUE(builtin::sub(rewriter, results, {maxValF16, oneF16}).failed());
566+
EXPECT_TRUE(builtin::sub(rewriter, results, {oneF16, minValF16}).failed());
569567
}
570568

571569
// float: correctness

0 commit comments

Comments
 (0)