Skip to content
34 changes: 25 additions & 9 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,8 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
return getLhs();

// Don't fold if it would overflow or if it requires a division by zero.
// TODO: This hook won't fold operations where a = MININT, because
// negating MININT overflows. This can be improved.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](APInt a, const APInt &b) {
Expand All @@ -701,22 +703,36 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
// Both positive, return ceil(a, b).
return signedCeilNonnegInputs(a, b, overflowOrDiv0);
}

// No folding happens if any of the intermediate arithmetic operations
// overflows.
bool overflowNegA = false;
bool overflowNegB = false;
bool overflowDiv = false;
bool overflowNegRes = false;
if (!aGtZero && !bGtZero) {
// Both negative, return ceil(-a, -b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
APInt posA = zero.ssub_ov(a, overflowNegA);
APInt posB = zero.ssub_ov(b, overflowNegB);
APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
return res;
}
if (!aGtZero && bGtZero) {
// A is negative, b is positive, return - ( -a / b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt div = posA.sdiv_ov(b, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
APInt posA = zero.ssub_ov(a, overflowNegA);
APInt div = posA.sdiv_ov(b, overflowDiv);
APInt res = zero.ssub_ov(div, overflowNegRes);
overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
return res;
}
// A is positive, b is negative, return - (a / -b).
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
APInt div = a.sdiv_ov(posB, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
APInt posB = zero.ssub_ov(b, overflowNegB);
APInt div = a.sdiv_ov(posB, overflowDiv);
APInt res = zero.ssub_ov(div, overflowNegRes);

overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
return res;
});

return overflowOrDiv0 ? Attribute() : result;
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Transforms/constant-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,44 @@ func.func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {

// -----

// CHECK-LABEL: func @simple_arith.ceildivsi_overflow
func.func @simple_arith.ceildivsi_overflow() -> (i8, i16, i32) {
// The negative values below are MININTs for the corresponding bit-width. The
// folder will try to negate them (so that the division operates on two
// positive numbers), but that would cause overflow (negating MININT
// overflows). Hence folding should not happen and the original ceildivsi is
// preserved.

// TODO: The folder should be able to fold the following by avoiding
// intermediate operations that overflow.

// CHECK-DAG: %[[C_1:.*]] = arith.constant 7 : i8
// CHECK-DAG: %[[MIN_I8:.*]] = arith.constant -128 : i8
// CHECK-DAG: %[[C_2:.*]] = arith.constant 7 : i16
// CHECK-DAG: %[[MIN_I16:.*]] = arith.constant -32768 : i16
// CHECK-DAG: %[[C_3:.*]] = arith.constant 7 : i32
// CHECK-DAG: %[[MIN_I32:.*]] = arith.constant -2147483648 : i32

// CHECK-NEXT: %[[CEILDIV_1:.*]] = arith.ceildivsi %[[MIN_I8]], %[[C_1]] : i8
%0 = arith.constant 7 : i8
%min_int_i8 = arith.constant -128 : i8
%2 = arith.ceildivsi %min_int_i8, %0 : i8

// CHECK-NEXT: %[[CEILDIV_2:.*]] = arith.ceildivsi %[[MIN_I16]], %[[C_2]] : i16
%3 = arith.constant 7 : i16
%min_int_i16 = arith.constant -32768 : i16
%5 = arith.ceildivsi %min_int_i16, %3 : i16

// CHECK-NEXT: %[[CEILDIV_2:.*]] = arith.ceildivsi %[[MIN_I32]], %[[C_3]] : i32
%6 = arith.constant 7 : i32
%min_int_i32 = arith.constant -2147483648 : i32
%8 = arith.ceildivsi %min_int_i32, %6 : i32

return %2, %5, %8 : i8, i16, i32
}

// -----

// CHECK-LABEL: func @simple_arith.ceildivui
func.func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
// CHECK-DAG: [[C0:%.+]] = arith.constant 0
Expand Down