Skip to content

Commit b851e07

Browse files
committed
address review comments
Signed-off-by: James Newling <james.newling@gmail.com>
1 parent c8087b0 commit b851e07

File tree

2 files changed

+24
-101
lines changed

2 files changed

+24
-101
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7620,18 +7620,22 @@ namespace {
76207620
/// %out = arith.constant dense<false> : vector<3xi1>.
76217621
/// ```
76227622
///
7623-
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
7624-
/// false at ALL indices we fold. If the constant was 1, then
7625-
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively
7626-
/// preferring the 'compact' vector.step representation.
7623+
/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
7624+
/// is false at ALL indices we fold. If the constant was 1, then
7625+
/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
7626+
/// conservatively preferring the 'compact' vector.step representation.
7627+
///
7628+
/// Note: this folder only works for the case where the constant (`%cst` above)
7629+
/// is the second operand of the comparison. The arith.cmpi canonicalizer will
7630+
/// ensure that constants are always second (on the right).
76277631
struct StepCompareFolder : public OpRewritePattern<StepOp> {
76287632
using Base::Base;
76297633

76307634
LogicalResult matchAndRewrite(StepOp stepOp,
76317635
PatternRewriter &rewriter) const override {
76327636
const int64_t stepSize = stepOp.getResult().getType().getNumElements();
76337637

7634-
for (auto &use : stepOp.getResult().getUses()) {
7638+
for (OpOperand &use : stepOp.getResult().getUses()) {
76357639
auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
76367640
if (!cmpiOp)
76377641
continue;
@@ -7644,7 +7648,8 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76447648
// Check that operand 1 is a constant.
76457649
unsigned constOperandNumber = 1;
76467650
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7647-
auto maybeConstValue = getConstantIntValue(otherOperand);
7651+
std::optional<int64_t> maybeConstValue =
7652+
getConstantIntValue(otherOperand);
76487653
if (!maybeConstValue.has_value())
76497654
continue;
76507655

mlir/test/Dialect/Vector/canonicalize/vector-step.mlir

Lines changed: 13 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
func.func @ugt_constant_3_lhs() -> vector<3xi1> {
1616
%cst = arith.constant dense<3> : vector<3xindex>
1717
%0 = vector.step : vector<3xindex>
18-
// 3 > [0, 1, 2] => true
18+
// 3 > [0, 1, 2] => [true, true, true] => true for all indices => fold
1919
%1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
2020
return %1 : vector<3xi1>
2121
}
@@ -28,20 +28,7 @@ func.func @ugt_constant_3_lhs() -> vector<3xi1> {
2828
func.func @negative_ugt_constant_2_lhs() -> vector<3xi1> {
2929
%cst = arith.constant dense<2> : vector<3xindex>
3030
%0 = vector.step : vector<3xindex>
31-
// 2 > [0, 1, 2] => not constant
32-
%1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
33-
return %1 : vector<3xi1>
34-
}
35-
36-
// -----
37-
38-
// CHECK-LABEL: @negative_ugt_constant_1_lhs
39-
// CHECK: %[[CMP:.*]] = arith.cmpi
40-
// CHECK: return %[[CMP]]
41-
func.func @negative_ugt_constant_1_lhs() -> vector<3xi1> {
42-
%cst = arith.constant dense<1> : vector<3xindex>
43-
%0 = vector.step : vector<3xindex>
44-
// 1 > [0, 1, 2] => not constant
31+
// 2 > [0, 1, 2] => [true, true, false] => not same for all indices => don't fold
4532
%1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
4633
return %1 : vector<3xi1>
4734
}
@@ -54,7 +41,7 @@ func.func @negative_ugt_constant_1_lhs() -> vector<3xi1> {
5441
func.func @ugt_constant_3_rhs() -> vector<3xi1> {
5542
%cst = arith.constant dense<3> : vector<3xindex>
5643
%0 = vector.step : vector<3xindex>
57-
// [0, 1, 2] > 3 => false
44+
// [0, 1, 2] > 3 => [false, false, false] => false for all indices => fold
5845
%1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
5946
return %1 : vector<3xi1>
6047
}
@@ -81,7 +68,7 @@ func.func @ugt_constant_max_rhs() -> vector<3xi1> {
8168
func.func @ugt_constant_2_rhs() -> vector<3xi1> {
8269
%cst = arith.constant dense<2> : vector<3xindex>
8370
%0 = vector.step : vector<3xindex>
84-
// [0, 1, 2] > 2 => false
71+
// [0, 1, 2] > 2 => [false, false, false] => false for all indices => fold
8572
%1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
8673
return %1 : vector<3xi1>
8774
}
@@ -94,7 +81,7 @@ func.func @ugt_constant_2_rhs() -> vector<3xi1> {
9481
func.func @negative_ugt_constant_1_rhs() -> vector<3xi1> {
9582
%cst = arith.constant dense<1> : vector<3xindex>
9683
%0 = vector.step : vector<3xindex>
97-
// [0, 1, 2] > 1 => not constant
84+
// [0, 1, 2] > 1 => [false, false, true] => not same for all indices => don't fold
9885
%1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
9986
return %1 : vector<3xi1>
10087
}
@@ -105,26 +92,14 @@ func.func @negative_ugt_constant_1_rhs() -> vector<3xi1> {
10592
/// Tests of `uge` (unsigned greater than or equal)
10693
///===------------------------------------===//
10794

108-
// CHECK-LABEL: @uge_constant_3_lhs
109-
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
110-
// CHECK: return %[[CST]] : vector<3xi1>
111-
func.func @uge_constant_3_lhs() -> vector<3xi1> {
112-
%cst = arith.constant dense<3> : vector<3xindex>
113-
%0 = vector.step : vector<3xindex>
114-
// 3 >= [0, 1, 2] => true
115-
%1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
116-
return %1 : vector<3xi1>
117-
}
118-
119-
// -----
12095

12196
// CHECK-LABEL: @uge_constant_2_lhs
12297
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
12398
// CHECK: return %[[CST]] : vector<3xi1>
12499
func.func @uge_constant_2_lhs() -> vector<3xi1> {
125100
%cst = arith.constant dense<2> : vector<3xindex>
126101
%0 = vector.step : vector<3xindex>
127-
// 2 >= [0, 1, 2] => true
102+
// 2 >= [0, 1, 2] => [true, true, true] => true for all indices => fold
128103
%1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
129104
return %1 : vector<3xi1>
130105
}
@@ -137,7 +112,7 @@ func.func @uge_constant_2_lhs() -> vector<3xi1> {
137112
func.func @negative_uge_constant_1_lhs() -> vector<3xi1> {
138113
%cst = arith.constant dense<1> : vector<3xindex>
139114
%0 = vector.step : vector<3xindex>
140-
// 1 >= [0, 1, 2] => not constant
115+
// 1 >= [0, 1, 2] => [true, false, false] => not same for all indices => don't fold
141116
%1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
142117
return %1 : vector<3xi1>
143118
}
@@ -150,7 +125,7 @@ func.func @negative_uge_constant_1_lhs() -> vector<3xi1> {
150125
func.func @uge_constant_3_rhs() -> vector<3xi1> {
151126
%cst = arith.constant dense<3> : vector<3xindex>
152127
%0 = vector.step : vector<3xindex>
153-
// [0, 1, 2] >= 3 => false
128+
// [0, 1, 2] >= 3 => [false, false, false] => false for all indices => fold
154129
%1 = arith.cmpi uge, %0, %cst : vector<3xindex>
155130
return %1 : vector<3xi1>
156131
}
@@ -163,50 +138,26 @@ func.func @uge_constant_3_rhs() -> vector<3xi1> {
163138
func.func @negative_uge_constant_2_rhs() -> vector<3xi1> {
164139
%cst = arith.constant dense<2> : vector<3xindex>
165140
%0 = vector.step : vector<3xindex>
166-
// [0, 1, 2] >= 2 => not constant
141+
// [0, 1, 2] >= 2 => [false, false, true] => not same for all indices => don't fold
167142
%1 = arith.cmpi uge, %0, %cst : vector<3xindex>
168143
return %1 : vector<3xi1>
169144
}
170145

171146
// -----
172147

173-
// CHECK-LABEL: @negative_uge_constant_1_rhs
174-
// CHECK: %[[CMP:.*]] = arith.cmpi
175-
// CHECK: return %[[CMP]]
176-
func.func @negative_uge_constant_1_rhs() -> vector<3xi1> {
177-
%cst = arith.constant dense<1> : vector<3xindex>
178-
%0 = vector.step : vector<3xindex>
179-
// [0, 1, 2] >= 1 => not constant
180-
%1 = arith.cmpi uge, %0, %cst: vector<3xindex>
181-
return %1 : vector<3xi1>
182-
}
183-
184-
// -----
185-
186-
187148

188149
///===------------------------------------===//
189150
/// Tests of `ult` (unsigned less than)
190151
///===------------------------------------===//
191152

192-
// CHECK-LABEL: @ult_constant_3_lhs
193-
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
194-
// CHECK: return %[[CST]] : vector<3xi1>
195-
func.func @ult_constant_3_lhs() -> vector<3xi1> {
196-
%cst = arith.constant dense<3> : vector<3xindex>
197-
%0 = vector.step : vector<3xindex>
198-
%1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
199-
return %1 : vector<3xi1>
200-
}
201-
202-
// -----
203153

204154
// CHECK-LABEL: @ult_constant_2_lhs
205155
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
206156
// CHECK: return %[[CST]] : vector<3xi1>
207157
func.func @ult_constant_2_lhs() -> vector<3xi1> {
208158
%cst = arith.constant dense<2> : vector<3xindex>
209159
%0 = vector.step : vector<3xindex>
160+
// 2 < [0, 1, 2] => [false, false, false] => false for all indices => fold
210161
%1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
211162
return %1 : vector<3xi1>
212163
}
@@ -219,6 +170,7 @@ func.func @ult_constant_2_lhs() -> vector<3xi1> {
219170
func.func @negative_ult_constant_1_lhs() -> vector<3xi1> {
220171
%cst = arith.constant dense<1> : vector<3xindex>
221172
%0 = vector.step : vector<3xindex>
173+
// 1 < [0, 1, 2] => [false, false, true] => not same for all indices => don't fold
222174
%1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
223175
return %1 : vector<3xi1>
224176
}
@@ -231,6 +183,7 @@ func.func @negative_ult_constant_1_lhs() -> vector<3xi1> {
231183
func.func @ult_constant_3_rhs() -> vector<3xi1> {
232184
%cst = arith.constant dense<3> : vector<3xindex>
233185
%0 = vector.step : vector<3xindex>
186+
// [0, 1, 2] < 3 => [true, true, true] => true for all indices => fold
234187
%1 = arith.cmpi ult, %0, %cst : vector<3xindex>
235188
return %1 : vector<3xi1>
236189
}
@@ -243,24 +196,13 @@ func.func @ult_constant_3_rhs() -> vector<3xi1> {
243196
func.func @negative_ult_constant_2_rhs() -> vector<3xi1> {
244197
%cst = arith.constant dense<2> : vector<3xindex>
245198
%0 = vector.step : vector<3xindex>
199+
// [0, 1, 2] < 2 => [true, true, false] => not same for all indices => don't fold
246200
%1 = arith.cmpi ult, %0, %cst : vector<3xindex>
247201
return %1 : vector<3xi1>
248202
}
249203

250204
// -----
251205

252-
// CHECK-LABEL: @negative_ult_constant_1_rhs
253-
// CHECK: %[[CMP:.*]] = arith.cmpi
254-
// CHECK: return %[[CMP]]
255-
func.func @negative_ult_constant_1_rhs() -> vector<3xi1> {
256-
%cst = arith.constant dense<1> : vector<3xindex>
257-
%0 = vector.step : vector<3xindex>
258-
%1 = arith.cmpi ult, %0, %cst: vector<3xindex>
259-
return %1 : vector<3xi1>
260-
}
261-
262-
// -----
263-
264206
///===------------------------------------===//
265207
/// Tests of `ule` (unsigned less than or equal)
266208
///===------------------------------------===//
@@ -289,30 +231,6 @@ func.func @negative_ule_constant_2_lhs() -> vector<3xi1> {
289231

290232
// -----
291233

292-
// CHECK-LABEL: @negative_ule_constant_1_lhs
293-
// CHECK: %[[CMP:.*]] = arith.cmpi
294-
// CHECK: return %[[CMP]]
295-
func.func @negative_ule_constant_1_lhs() -> vector<3xi1> {
296-
%cst = arith.constant dense<1> : vector<3xindex>
297-
%0 = vector.step : vector<3xindex>
298-
%1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
299-
return %1 : vector<3xi1>
300-
}
301-
302-
// -----
303-
304-
// CHECK-LABEL: @ule_constant_3_rhs
305-
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
306-
// CHECK: return %[[CST]] : vector<3xi1>
307-
func.func @ule_constant_3_rhs() -> vector<3xi1> {
308-
%cst = arith.constant dense<3> : vector<3xindex>
309-
%0 = vector.step : vector<3xindex>
310-
%1 = arith.cmpi ule, %0, %cst : vector<3xindex>
311-
return %1 : vector<3xi1>
312-
}
313-
314-
// -----
315-
316234
// CHECK-LABEL: @ule_constant_2_rhs
317235
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
318236
// CHECK: return %[[CST]] : vector<3xi1>

0 commit comments

Comments
 (0)