Skip to content

[mlir][arith] Fix overflow bug in arith::CeilDivSIOp::fold #90947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 8, 2024
34 changes: 25 additions & 9 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,8 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
return getLhs();

// Don't fold if it would overflow or if it requires a division by zero.
// TODO: This hook won't fold operations where an operand that has to be
// negated (a or b) is MININT, because negating MININT overflows. This can be
// improved.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](APInt a, const APInt &b) {
Expand All @@ -701,22 +703,36 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
// Both positive, return ceil(a, b).
return signedCeilNonnegInputs(a, b, overflowOrDiv0);
}

// No folding happens if any of the intermediate arithmetic operations
// overflows.
bool overflowNegA = false;
bool overflowNegB = false;
bool overflowDiv = false;
bool overflowNegRes = false;
if (!aGtZero && !bGtZero) {
// Both negative, return ceil(-a, -b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
APInt posA = zero.ssub_ov(a, overflowNegA);
APInt posB = zero.ssub_ov(b, overflowNegB);
APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
return res;
}
if (!aGtZero && bGtZero) {
// A is negative, b is positive, return - ( -a / b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt div = posA.sdiv_ov(b, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
APInt posA = zero.ssub_ov(a, overflowNegA);
APInt div = posA.sdiv_ov(b, overflowDiv);
APInt res = zero.ssub_ov(div, overflowNegRes);
overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
return res;
}
// A is positive, b is negative, return - (a / -b).
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
APInt div = a.sdiv_ov(posB, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
APInt posB = zero.ssub_ov(b, overflowNegB);
APInt div = a.sdiv_ov(posB, overflowDiv);
APInt res = zero.ssub_ov(div, overflowNegRes);

overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
return res;
});

return overflowOrDiv0 ? Attribute() : result;
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Transforms/constant-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,44 @@ func.func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {

// -----

// CHECK-LABEL: func @simple_arith.ceildivsi_overflow
func.func @simple_arith.ceildivsi_overflow() -> (i8, i16, i32) {
  // Each dividend below is MININT for its bit-width. To divide two
  // non-negative numbers the folder first negates the negative operand, but
  // negating MININT overflows. The folder must detect that overflow and bail
  // out, so every arith.ceildivsi is expected to survive unchanged.

  // TODO: The folder should be able to fold the following by avoiding
  // intermediate operations that overflow.

  // CHECK-DAG: %[[C_1:.*]] = arith.constant 7 : i8
  // CHECK-DAG: %[[MIN_I8:.*]] = arith.constant -128 : i8
  // CHECK-DAG: %[[C_2:.*]] = arith.constant 7 : i16
  // CHECK-DAG: %[[MIN_I16:.*]] = arith.constant -32768 : i16
  // CHECK-DAG: %[[C_3:.*]] = arith.constant 7 : i32
  // CHECK-DAG: %[[MIN_I32:.*]] = arith.constant -2147483648 : i32

  // CHECK-NEXT: %[[CEILDIV_1:.*]] = arith.ceildivsi %[[MIN_I8]], %[[C_1]] : i8
  %0 = arith.constant 7 : i8
  %min_int_i8 = arith.constant -128 : i8
  %2 = arith.ceildivsi %min_int_i8, %0 : i8

  // CHECK-NEXT: %[[CEILDIV_2:.*]] = arith.ceildivsi %[[MIN_I16]], %[[C_2]] : i16
  %3 = arith.constant 7 : i16
  %min_int_i16 = arith.constant -32768 : i16
  %5 = arith.ceildivsi %min_int_i16, %3 : i16

  // CHECK-NEXT: %[[CEILDIV_3:.*]] = arith.ceildivsi %[[MIN_I32]], %[[C_3]] : i32
  %6 = arith.constant 7 : i32
  %min_int_i32 = arith.constant -2147483648 : i32
  %8 = arith.ceildivsi %min_int_i32, %6 : i32

  // The unfolded results must be what the function actually returns.
  // CHECK-NEXT: return %[[CEILDIV_1]], %[[CEILDIV_2]], %[[CEILDIV_3]]
  return %2, %5, %8 : i8, i16, i32
}

// -----

// CHECK-LABEL: func @simple_arith.ceildivui
func.func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
// CHECK-DAG: [[C0:%.+]] = arith.constant 0
Expand Down
Loading