Skip to content

Commit c6480bc

Browse files
committed
[mlir][arith] Fix overflow bug in arith::CeilDivSIOp::fold
The folder for arith::CeilDivSIOp should only be applied when it can be guaranteed that no overflow would happen. The current implementation works fine when both operands are positive and the only arithmetic operation is the division itself. However, in cases where at least one of the operands is negative, the division is split into multiple operations, e.g.: `- ( -a / b)`. That's 2 additional operations on top of the actual division that can overflow - the folder should check all 3 ops for overflow. The current logic doesn't do that - it effectively only checks the last operation: the same overflow flag is passed to every step, so each later operation overwrites the result of the earlier ones. It breaks when using e.g. MININT values (e.g. -128 for 8-bit integers) - negating such values overflows. This PR makes sure that no folding happens if any of the intermediate arithmetic operations overflows.
1 parent 33e16ca commit c6480bc

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -701,22 +701,35 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
701701
// Both positive, return ceil(a, b).
702702
return signedCeilNonnegInputs(a, b, overflowOrDiv0);
703703
}
704+
705+
bool overflowNegA = false;
706+
bool overflowNegB = false;
707+
bool overflowNegDiv = false;
708+
bool overflowDiv = false;
704709
if (!aGtZero && !bGtZero) {
705710
// Both negative, return ceil(-a, -b).
706-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
707-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
708-
return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
711+
APInt posA = zero.ssub_ov(a, overflowNegA);
712+
APInt posB = zero.ssub_ov(b, overflowNegB);
713+
APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
714+
overflowOrDiv0 =
715+
(overflowNegA || overflowNegB || overflowDiv);
716+
return res;
709717
}
710718
if (!aGtZero && bGtZero) {
711719
// A is negative, b is positive, return - ( -a / b).
712-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
713-
APInt div = posA.sdiv_ov(b, overflowOrDiv0);
714-
return zero.ssub_ov(div, overflowOrDiv0);
720+
APInt posA = zero.ssub_ov(a, overflowNegA);
721+
APInt div = posA.sdiv_ov(b, overflowDiv);
722+
APInt res = zero.ssub_ov(div, overflowNegDiv);
723+
overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegDiv);
724+
return res;
715725
}
716726
// A is positive, b is negative, return - (a / -b).
717-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
718-
APInt div = a.sdiv_ov(posB, overflowOrDiv0);
719-
return zero.ssub_ov(div, overflowOrDiv0);
727+
APInt posB = zero.ssub_ov(b, overflowNegB);
728+
APInt div = a.sdiv_ov(posB, overflowDiv);
729+
APInt res = zero.ssub_ov(div, overflowNegDiv);
730+
731+
overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegDiv);
732+
return res;
720733
});
721734

722735
return overflowOrDiv0 ? Attribute() : result;

mlir/test/Transforms/constant-fold.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,26 @@ func.func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {
478478

479479
// -----
480480

481+
// CHECK-LABEL: func @simple_arith.ceildivsi_overflow
482+
func.func @simple_arith.ceildivsi_overflow() -> (i8, i16, i32) {
483+
// CHECK-COUNT-3: arith.ceildivsi
484+
%0 = arith.constant 7 : i8
485+
%1 = arith.constant -128 : i8
486+
%2 = arith.ceildivsi %1, %0 : i8
487+
488+
%3 = arith.constant 7 : i16
489+
%4 = arith.constant -32768 : i16
490+
%5 = arith.ceildivsi %4, %3 : i16
491+
492+
%6 = arith.constant 7 : i32
493+
%7 = arith.constant -2147483648 : i32
494+
%8 = arith.ceildivsi %7, %6 : i32
495+
496+
return %2, %5, %8 : i8, i16, i32
497+
}
498+
499+
// -----
500+
481501
// CHECK-LABEL: func @simple_arith.ceildivui
482502
func.func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
483503
// CHECK-DAG: [[C0:%.+]] = arith.constant 0

0 commit comments

Comments
 (0)