Skip to content

Commit 3920004

Browse files
committed
[mlir][arith] fix ceildivsi lowering
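This commit fixes the overflow in the lowering of ceildivsi( <signed_type_min>, <positive_integer> ): the previous negative-result path computed -((-n) / m), and negating the signed minimum value overflows.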
1 parent be1958f commit 3920004

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
6060

6161
/// Expands CeilDivSIOp (n, m) into
6262
/// 1) x = (m > 0) ? -1 : 1
63-
/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
63+
/// 2) (n*m>0) ? ((n+x) / m) + 1 : n / m
6464
struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
6565
using OpRewritePattern::OpRewritePattern;
6666
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -80,10 +80,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
8080
Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
8181
Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
8282
Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
83-
// Compute negative res: - ((-a)/b).
84-
Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85-
Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86-
Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
83+
// Compute negative res: a/b.
84+
Value negRes = rewriter.create<arith::DivSIOp>(loc, a, b);
8785
// Result is (a*b>0) ? pos result : neg result.
8886
// Note, we want to avoid using a*b because of possible overflow.
8987
// The cases that matter are a>0, a==0, a<0, b>0 and b<0. We do

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,15 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
1515
// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
1616
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
1717
// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
18-
// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
19-
// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
20-
// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
18+
// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
2119
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
2220
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
2321
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
2422
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
2523
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
2624
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
2725
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
28-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
26+
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
2927
}
3028

3129
// -----
@@ -45,17 +43,15 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
4543
// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
4644
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
4745
// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
48-
// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
49-
// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
50-
// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
46+
// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
5147
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
5248
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
5349
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
5450
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
5551
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
5652
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
5753
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
58-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
54+
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
5955
}
6056

6157
// -----

0 commit comments

Comments
 (0)