Skip to content

Commit 1fff7f9

Browse files
committed
[mlir][arith] fix wrong floordivsi fold (#83079)
Fixs #83079
1 parent 512a8a7 commit 1fff7f9

File tree

5 files changed

+55
-33
lines changed

5 files changed

+55
-33
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,7 @@ class [[nodiscard]] APInt {
996996
APInt sshl_ov(unsigned Amt, bool &Overflow) const;
997997
APInt ushl_ov(const APInt &Amt, bool &Overflow) const;
998998
APInt ushl_ov(unsigned Amt, bool &Overflow) const;
999+
APInt sfloordiv_ov(const APInt &RHS, bool &Overflow) const;
9991000

10001001
// Operations that saturate
10011002
APInt sadd_sat(const APInt &RHS) const;

llvm/lib/Support/APInt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2022,6 +2022,14 @@ APInt APInt::ushl_ov(unsigned ShAmt, bool &Overflow) const {
20222022
return *this << ShAmt;
20232023
}
20242024

2025+
APInt APInt::sfloordiv_ov(const APInt &RHS, bool &Overflow) const {
2026+
auto quotient = sdiv_ov(RHS, Overflow);
2027+
if ((quotient * RHS != *this) && (isNegative() != RHS.isNegative()))
2028+
return quotient - 1;
2029+
else
2030+
return quotient;
2031+
}
2032+
20252033
APInt APInt::sadd_sat(const APInt &RHS) const {
20262034
bool Overflow;
20272035
APInt Res = sadd_ov(RHS, Overflow);

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/Support/Alignment.h"
1515
#include "gtest/gtest.h"
1616
#include <array>
17+
#include <limits>
1718
#include <optional>
1819

1920
using namespace llvm;
@@ -2893,6 +2894,39 @@ TEST(APIntTest, smul_ov) {
28932894
}
28942895
}
28952896

2897+
TEST(APIntTest, sfloordiv_ov) {
2898+
{
2899+
APInt divisor(32, -3, true);
2900+
APInt dividend(32, 2, true);
2901+
bool Overflow = false;
2902+
auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
2903+
EXPECT_FALSE(Overflow);
2904+
EXPECT_EQ(-2, quotient.getSExtValue());
2905+
}
2906+
{
2907+
APInt divisor(32, std::numeric_limits<int>::lowest(), true);
2908+
APInt dividend(32, -1, true);
2909+
bool Overflow = false;
2910+
[[maybe_unused]] auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
2911+
EXPECT_TRUE(Overflow);
2912+
}
2913+
{
2914+
auto check_overflow_one = [](auto arg) {
2915+
using IntTy = decltype(arg);
2916+
APInt divisor(8 * sizeof(arg), std::numeric_limits<IntTy>::lowest(),
2917+
true);
2918+
APInt dividend(8 * sizeof(arg), IntTy(-1), true);
2919+
bool Overflow = false;
2920+
[[maybe_unused]] auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
2921+
EXPECT_TRUE(Overflow);
2922+
};
2923+
auto check_overflow_all = [&](auto... args) {
2924+
(void)std::initializer_list<int>{(check_overflow_one(args), 0)...};
2925+
};
2926+
std::apply(check_overflow_all, std::tuple<char, short, int, int64_t>());
2927+
}
2928+
}
2929+
28962930
TEST(APIntTest, SolveQuadraticEquationWrap) {
28972931
// Verify that "Solution" is the first non-negative integer that solves
28982932
// Ax^2 + Bx + C = "0 or overflow", i.e. that it is a correct solution

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -689,43 +689,13 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
689689
return getLhs();
690690

691691
// Don't fold if it would overflow or if it requires a division by zero.
692-
bool overflowOrDiv0 = false;
692+
bool overflowOrDiv = false;
693693
auto result = constFoldBinaryOp<IntegerAttr>(
694694
adaptor.getOperands(), [&](APInt a, const APInt &b) {
695-
if (overflowOrDiv0 || !b) {
696-
overflowOrDiv0 = true;
697-
return a;
698-
}
699-
if (!a)
700-
return a;
701-
// After this point we know that neither a or b are zero.
702-
unsigned bits = a.getBitWidth();
703-
APInt zero = APInt::getZero(bits);
704-
bool aGtZero = a.sgt(zero);
705-
bool bGtZero = b.sgt(zero);
706-
if (aGtZero && bGtZero) {
707-
// Both positive, return a / b.
708-
return a.sdiv_ov(b, overflowOrDiv0);
709-
}
710-
if (!aGtZero && !bGtZero) {
711-
// Both negative, return -a / -b.
712-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
713-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
714-
return posA.sdiv_ov(posB, overflowOrDiv0);
715-
}
716-
if (!aGtZero && bGtZero) {
717-
// A is negative, b is positive, return - ceil(-a, b).
718-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
719-
APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
720-
return zero.ssub_ov(ceil, overflowOrDiv0);
721-
}
722-
// A is positive, b is negative, return - ceil(a, -b).
723-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
724-
APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
725-
return zero.ssub_ov(ceil, overflowOrDiv0);
695+
return a.sfloordiv_ov(b, overflowOrDiv);
726696
});
727697

728-
return overflowOrDiv0 ? Attribute() : result;
698+
return overflowOrDiv ? Attribute() : result;
729699
}
730700

731701
//===----------------------------------------------------------------------===//

mlir/test/Transforms/canonicalize.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,15 @@ func.func @tensor_arith.floordivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5x
989989
return %res : tensor<4x5xi32>
990990
}
991991

992+
// CHECK-LABEL: func @arith.floordivsi_by_one_overflow
993+
func.func @arith.floordivsi_by_one_overflow() -> i64 {
994+
%neg_one = arith.constant -1 : i64
995+
%min_int = arith.constant -9223372036854775808 : i64
996+
// CHECK: arith.floordivsi
997+
%poision = arith.floordivsi %min_int, %neg_one : i64
998+
return %poision : i64
999+
}
1000+
9921001
// -----
9931002

9941003
// CHECK-LABEL: func @arith.ceildivsi_by_one

0 commit comments

Comments
 (0)