Skip to content

Commit a03a5b6

Browse files
author
Aviad Cohen
committed
[mlir][scf]: Add value bound between scf for loop yield and result
We can prove that: %result == %init_arg + trip_count * (%yielded_value - %iter_arg). Where trip_count is (ub - lb) / step.
1 parent c23f241 commit a03a5b6

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,22 @@ struct ForOpInterface
7070
cstr.bound(value) == cstr.getExpr(initArg);
7171
}
7272
}
73+
74+
if (dim.has_value() || isa<BlockArgument>(value))
75+
return;
76+
77+
// `value` is result of `forOp`, we can prove that:
78+
// %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
79+
// Where trip_count is (ub - lb) / step.
80+
AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
81+
AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
82+
AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
83+
AffineExpr tripCountExpr =
84+
AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step
85+
AffineExpr oneIterAdvanceExpr =
86+
cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg);
87+
cstr.bound(value) ==
88+
cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
7389
}
7490

7591
void populateBoundsForIndexValue(Operation *op, Value value,

mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,74 @@ func.func @compare_scf_for(%a: index, %b: index, %c: index) {
267267
}
268268
return
269269
}
270+
271+
// -----
272+
273+
func.func @scf_for_result_infer() {
274+
%c0 = arith.constant 0 : index
275+
%c1 = arith.constant 1 : index
276+
%c10 = arith.constant 10 : index
277+
%0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %c0) -> index {
278+
%2 = "test.some_use"() : () -> (i1)
279+
%3 = scf.if %2 -> (index) {
280+
%5 = arith.addi %arg, %c1 : index
281+
scf.yield %5 : index
282+
} else {
283+
scf.yield %arg : index
284+
}
285+
scf.yield %3 : index
286+
}
287+
// expected-remark @below{{true}}
288+
"test.compare"(%0, %c10) {cmp = "LE"} : (index, index) -> ()
289+
return
290+
}
291+
292+
// -----
293+
294+
func.func @scf_for_result_infer_dynamic_init(%i : index) {
295+
%c0 = arith.constant 0 : index
296+
%c1 = arith.constant 1 : index
297+
%c10 = arith.constant 10 : index
298+
%0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %i) -> index {
299+
%2 = "test.some_use"() : () -> (i1)
300+
%3 = scf.if %2 -> (index) {
301+
%5 = arith.addi %arg, %c1 : index
302+
scf.yield %5 : index
303+
} else {
304+
scf.yield %arg : index
305+
}
306+
scf.yield %3 : index
307+
}
308+
%6 = arith.addi %i, %c10 : index
309+
// expected-remark @below{{true}}
310+
"test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
311+
return
312+
}
313+
314+
// -----
315+
316+
func.func @scf_for_result_infer_dynamic_init_big_step(%i : index) {
317+
%c0 = arith.constant 0 : index
318+
%c1 = arith.constant 1 : index
319+
%c2 = arith.constant 2 : index
320+
%c4 = arith.constant 4 : index
321+
%c5 = arith.constant 5 : index
322+
%c10 = arith.constant 10 : index
323+
%0 = scf.for %iv = %c0 to %c10 step %c2 iter_args(%arg = %i) -> index {
324+
%2 = "test.some_use"() : () -> (i1)
325+
%3 = scf.if %2 -> (index) {
326+
%5 = arith.addi %arg, %c1 : index
327+
scf.yield %5 : index
328+
} else {
329+
scf.yield %arg : index
330+
}
331+
scf.yield %3 : index
332+
}
333+
%6 = arith.addi %i, %c5 : index
334+
%7 = arith.addi %i, %c4 : index
335+
// expected-remark @below{{true}}
336+
"test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
337+
// expected-error @below{{unknown}}
338+
"test.compare"(%0, %7) {cmp = "LE"} : (index, index) -> ()
339+
return
340+
}

0 commit comments

Comments
 (0)