Skip to content

Commit c241567

Browse files
committed
[mlir][arith] fix wrong floordivsi fold (#83079)
1) Fix the incorrect floordivsi expansion logic. 2) Fix the floordivsi fold, which didn't check the overflow state. Fixes #83079
1 parent b0d1e32 commit c241567

File tree

5 files changed

+93
-115
lines changed

5 files changed

+93
-115
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -689,43 +689,17 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
689689
return getLhs();
690690

691691
// Don't fold if it would overflow or if it requires a division by zero.
692-
bool overflowOrDiv0 = false;
692+
bool overflowOrDiv = false;
693693
auto result = constFoldBinaryOp<IntegerAttr>(
694694
adaptor.getOperands(), [&](APInt a, const APInt &b) {
695-
if (overflowOrDiv0 || !b) {
696-
overflowOrDiv0 = true;
695+
if (b.isZero()) {
696+
overflowOrDiv = true;
697697
return a;
698698
}
699-
if (!a)
700-
return a;
701-
// After this point we know that neither a or b are zero.
702-
unsigned bits = a.getBitWidth();
703-
APInt zero = APInt::getZero(bits);
704-
bool aGtZero = a.sgt(zero);
705-
bool bGtZero = b.sgt(zero);
706-
if (aGtZero && bGtZero) {
707-
// Both positive, return a / b.
708-
return a.sdiv_ov(b, overflowOrDiv0);
709-
}
710-
if (!aGtZero && !bGtZero) {
711-
// Both negative, return -a / -b.
712-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
713-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
714-
return posA.sdiv_ov(posB, overflowOrDiv0);
715-
}
716-
if (!aGtZero && bGtZero) {
717-
// A is negative, b is positive, return - ceil(-a, b).
718-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
719-
APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
720-
return zero.ssub_ov(ceil, overflowOrDiv0);
721-
}
722-
// A is positive, b is negative, return - ceil(a, -b).
723-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
724-
APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
725-
return zero.ssub_ov(ceil, overflowOrDiv0);
699+
return a.sfloordiv_ov(b, overflowOrDiv);
726700
});
727701

728-
return overflowOrDiv0 ? Attribute() : result;
702+
return overflowOrDiv ? Attribute() : result;
729703
}
730704

731705
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
110110
}
111111
};
112112

113-
/// Expands FloorDivSIOp (n, m) into
114-
/// 1) x = (m<0) ? 1 : -1
115-
/// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
113+
/// Expands FloorDivSIOp (x, y) into
114+
/// z = x / y
115+
/// if (z * y != x && (x < 0) != (y < 0)) {
116+
/// return z - 1;
117+
/// } else {
118+
/// return z;
119+
/// }
116120
struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
117121
using OpRewritePattern::OpRewritePattern;
118122
LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
@@ -121,41 +125,29 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
121125
Type type = op.getType();
122126
Value a = op.getLhs();
123127
Value b = op.getRhs();
124-
Value plusOne = createConst(loc, type, 1, rewriter);
128+
129+
Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130+
Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131+
Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132+
loc, arith::CmpIPredicate::ne, a, product);
125133
Value zero = createConst(loc, type, 0, rewriter);
126-
Value minusOne = createConst(loc, type, -1, rewriter);
127-
// Compute x = (b<0) ? 1 : -1.
128-
Value compare =
129-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
130-
Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
131-
// Compute negative res: -1 - ((x-a)/b).
132-
Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
133-
Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
134-
Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
135-
// Compute positive res: a/b.
136-
Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
137-
// Result is (a*b<0) ? negative result : positive result.
138-
// Note, we want to avoid using a*b because of possible overflow.
139-
// The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
140-
// not particuliarly care if a*b<0 is true or false when b is zero
141-
// as this will result in an illegal divide. So `a*b<0` can be reformulated
142-
// as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
143-
// We pick the first expression here.
134+
144135
Value aNeg =
145136
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
146-
Value aPos =
147-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
148137
Value bNeg =
149138
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
150-
Value bPos =
151-
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
152-
Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
153-
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
154-
Value compareRes =
155-
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
156-
// Perform substitution and return success.
157-
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
158-
posRes);
139+
140+
Value signOpposite = rewriter.create<arith::CmpIOp>(
141+
loc, arith::CmpIPredicate::ne, aNeg, bNeg);
142+
Value cond =
143+
rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
144+
145+
Value minusOne = createConst(loc, type, -1, rewriter);
146+
Value quotientMinusOne =
147+
rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
148+
149+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
150+
quotient);
159151
return success();
160152
}
161153
};

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 33 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -66,23 +66,17 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
6666
func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
6767
%res = arith.floordivsi %arg0, %arg1 : i32
6868
return %res : i32
69-
// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
70-
// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
71-
// CHECK: [[MIN1:%.+]] = arith.constant -1 : i32
72-
// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
73-
// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : i32
74-
// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : i32
75-
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
76-
// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : i32
77-
// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
78-
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
79-
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
80-
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
81-
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
82-
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
83-
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
84-
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
85-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
69+
// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : i32
70+
// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : i32
71+
// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : i32
72+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32
73+
// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : i32
74+
// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : i32
75+
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
76+
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
77+
// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant -1 : i32
78+
// CHECK: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : i32
79+
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : i32
8680
}
8781

8882
// -----
@@ -93,23 +87,17 @@ func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
9387
func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
9488
%res = arith.floordivsi %arg0, %arg1 : index
9589
return %res : index
96-
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
97-
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
98-
// CHECK: [[MIN1:%.+]] = arith.constant -1 : index
99-
// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
100-
// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : index
101-
// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
102-
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
103-
// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
104-
// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
105-
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
106-
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
107-
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
108-
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
109-
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
110-
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
111-
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
112-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
90+
// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : index
91+
// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : index
92+
// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : index
93+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
94+
// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : index
95+
// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : index
96+
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
97+
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
98+
// CHECK: %[[NEG_ONE:.*]] = arith.constant -1 : index
99+
// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : index
100+
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : index
113101
}
114102

115103
// -----
@@ -121,23 +109,17 @@ func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
121109
func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
122110
%res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
123111
return %res : vector<4xi32>
124-
// CHECK: %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32>
125-
// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32>
126-
// CHECK: %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32>
127-
// CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
128-
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32>
129-
// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32>
130-
// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32>
131-
// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32>
132-
// CHECK: %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
133-
// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
134-
// CHECK: %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
135-
// CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
136-
// CHECK: %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
137-
// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1>
138-
// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1>
139-
// CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1>
140-
// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32>
112+
// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : vector<4xi32>
113+
// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : vector<4xi32>
114+
// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : vector<4xi32>
115+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : vector<4xi32>
116+
// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : vector<4xi32>
117+
// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : vector<4xi32>
118+
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : vector<4xi1>
119+
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : vector<4xi1>
120+
// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant dense<-1> : vector<4xi32>
121+
// CHECK: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32>
122+
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : vector<4xi1>, vector<4xi32>
141123
}
142124

143125
// -----

mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
33
// RUN: -shared-libs=%mlir_c_runner_utils | \
44
// RUN: FileCheck %s
5+
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf,lower-affine,convert-scf-to-cf,memref-expand,arith-expand),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" | \
6+
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
7+
// RUN: -shared-libs=%mlir_c_runner_utils | \
8+
// RUN: FileCheck %s --check-prefix=SCHECK
59

610
func.func @transfer_read_2d(%A : memref<40xi32>, %base1: index) {
711
%i42 = arith.constant -42: i32
@@ -101,3 +105,20 @@ func.func @entry() {
101105
// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
102106
// CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 )
103107
// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
108+
109+
// -----
110+
111+
func.func @non_inline_function() -> (i64, i64) {
112+
%MIN_INT_MINUS_ONE = arith.constant -9223372036854775807 : i64
113+
%NEG_ONE = arith.constant -1 : i64
114+
return %MIN_INT_MINUS_ONE, %NEG_ONE : i64, i64
115+
}
116+
117+
func.func @main() {
118+
%0:2 = call @non_inline_function() : () -> (i64, i64)
119+
%1 = arith.floordivsi %0#0, %0#1 : i64
120+
vector.print %1 : i64
121+
return
122+
}
123+
124+
// SCHECK: 9223372036854775807

mlir/test/Transforms/canonicalize.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,15 @@ func.func @tensor_arith.floordivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5x
989989
return %res : tensor<4x5xi32>
990990
}
991991

992+
// CHECK-LABEL: func @arith.floordivsi_by_one_overflow
993+
func.func @arith.floordivsi_by_one_overflow() -> i64 {
994+
%neg_one = arith.constant -1 : i64
995+
%min_int = arith.constant -9223372036854775808 : i64
996+
// CHECK: arith.floordivsi
997+
%poision = arith.floordivsi %min_int, %neg_one : i64
998+
return %poision : i64
999+
}
1000+
9921001
// -----
9931002

9941003
// CHECK-LABEL: func @arith.ceildivsi_by_one

0 commit comments

Comments
 (0)