Skip to content

Commit 6dcdd66

Browse files
[mlir][Interfaces][NFC] ValueBoundsConstraintSet: Pass stop condition in the constructor
This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed as a reference to the stop function, so that the stop function can be defined before the the `ValueBoundsConstraintSet` is constructed. This change is in preparation of adding support for branches.
1 parent 35886dc commit 6dcdd66

File tree

9 files changed

+90
-63
lines changed

9 files changed

+90
-63
lines changed

mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet {
2929
struct ScalableValueBoundsConstraintSet
3030
: public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
3131
detail::ValueBoundsConstraintSet> {
32-
ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
33-
unsigned vscaleMax)
34-
: RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){};
32+
ScalableValueBoundsConstraintSet(
33+
MLIRContext *context,
34+
ValueBoundsConstraintSet::StopConditionFn stopCondition,
35+
unsigned vscaleMin, unsigned vscaleMax)
36+
: RTTIExtends(context, stopCondition), vscaleMin(vscaleMin),
37+
vscaleMax(vscaleMax) {};
3538

3639
using RTTIExtends::bound;
3740
using RTTIExtends::StopConditionFn;

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,9 @@ class ValueBoundsConstraintSet
117117
///
118118
/// The first parameter of the function is the shaped value/index-typed
119119
/// value. The second parameter is the dimension in case of a shaped value.
120-
using StopConditionFn =
121-
function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
120+
/// The third parameter is this constraint set.
121+
using StopConditionFn = std::function<bool(
122+
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
122123

123124
/// Compute a bound for the given index-typed value or shape dimension size.
124125
/// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -271,22 +272,20 @@ class ValueBoundsConstraintSet
271272
/// An index-typed value or the dimension of a shaped-type value.
272273
using ValueDim = std::pair<Value, int64_t>;
273274

274-
ValueBoundsConstraintSet(MLIRContext *ctx);
275+
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
275276

276277
/// Populates the constraint set for a value/map without actually computing
277278
/// the bound. Returns the position for the value/map (via the return value
278279
/// and `posOut` output parameter).
279280
int64_t populateConstraintsSet(Value value,
280-
std::optional<int64_t> dim = std::nullopt,
281-
StopConditionFn stopCondition = nullptr);
281+
std::optional<int64_t> dim = std::nullopt);
282282
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
283-
StopConditionFn stopCondition = nullptr,
284283
int64_t *posOut = nullptr);
285284

286285
/// Iteratively process all elements on the worklist until an index-typed
287286
/// value or shaped value meets `stopCondition`. Such values are not processed
288287
/// any further.
289-
void processWorklist(StopConditionFn stopCondition);
288+
void processWorklist();
290289

291290
/// Bound the given column in the underlying constraint set by the given
292291
/// expression.
@@ -333,6 +332,9 @@ class ValueBoundsConstraintSet
333332

334333
/// Builder for constructing affine expressions.
335334
Builder builder;
335+
336+
/// The current stop condition function.
337+
StopConditionFn stopCondition = nullptr;
336338
};
337339

338340
} // namespace mlir

mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
8484
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
8585
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
8686
bool closedUB) {
87-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
87+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
88+
ValueBoundsConstraintSet &cstr) {
8889
// We are trying to reify a bound for `value` in terms of the owning op's
8990
// operands. Construct a stop condition that evaluates to "true" for any SSA
9091
// value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
100101
FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
101102
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
102103
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
103-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
104+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
105+
ValueBoundsConstraintSet &cstr) {
104106
return v != value;
105107
};
106108
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,

mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
119119
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
120120
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
121121
bool closedUB) {
122-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
122+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
123+
ValueBoundsConstraintSet &cstr) {
123124
// We are trying to reify a bound for `value` in terms of the owning op's
124125
// operands. Construct a stop condition that evaluates to "true" for any SSA
125126
// value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
135136
FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
136137
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
137138
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
138-
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
139+
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
140+
ValueBoundsConstraintSet &cstr) {
139141
return v != value;
140142
};
141143
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
468468
FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
469469
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
470470
/*stopCondition=*/
471-
[&](Value v, std::optional<int64_t> d) {
471+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
472472
if (v == forOp.getUpperBound())
473473
return false;
474474
// Compute a bound that is independent of any affine op results.

mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct ForOpInterface
5858
ValueDimList boundOperands;
5959
LogicalResult status = ValueBoundsConstraintSet::computeBound(
6060
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
61-
[&](Value v, std::optional<int64_t> d) {
61+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
6262
// Stop when reaching a block argument of the loop body.
6363
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
6464
return bbArg.getOwner()->getParentOp() == forOp;

mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,26 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
4747
unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
4848
StopConditionFn stopCondition) {
4949
using namespace presburger;
50-
5150
assert(vscaleMin <= vscaleMax);
52-
ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
53-
vscaleMax);
5451

55-
int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
52+
// No stop condition specified: Keep adding constraints until the worklist
53+
// is empty.
54+
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
55+
mlir::ValueBoundsConstraintSet &cstr) {
56+
return false;
57+
};
58+
59+
ScalableValueBoundsConstraintSet scalableCstr(
60+
value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
61+
vscaleMin, vscaleMax);
62+
int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
5663

5764
// Project out all variables apart from vscale.
5865
// This should result in constraints in terms of vscale only.
59-
scalableCstr.projectOut(
60-
[&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
66+
auto projectOutFn = [&](ValueDim p) {
67+
return p.first != scalableCstr.getVscaleValue();
68+
};
69+
scalableCstr.projectOut(projectOutFn);
6170

6271
assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
6372
scalableCstr.positionToValueDim.size() &&

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
6767
return std::nullopt;
6868
}
6969

70-
ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
71-
: builder(ctx) {}
70+
ValueBoundsConstraintSet::ValueBoundsConstraintSet(
71+
MLIRContext *ctx, StopConditionFn stopCondition)
72+
: builder(ctx), stopCondition(stopCondition) {
73+
assert(stopCondition && "expected non-null stop condition");
74+
}
7275

7376
char ValueBoundsConstraintSet::ID = 0;
7477

@@ -193,7 +196,8 @@ static Operation *getOwnerOfValue(Value value) {
193196
return value.getDefiningOp();
194197
}
195198

196-
void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
199+
void ValueBoundsConstraintSet::processWorklist() {
200+
LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
197201
while (!worklist.empty()) {
198202
int64_t pos = worklist.front();
199203
worklist.pop();
@@ -214,13 +218,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
214218

215219
// Do not process any further if the stop condition is met.
216220
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
217-
if (stopCondition(value, maybeDim))
221+
if (stopCondition(value, maybeDim, *this)) {
222+
LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
223+
<< " (dim: " << maybeDim << ")\n");
218224
continue;
225+
}
219226

220227
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
221228
// the worklist.
222229
auto valueBoundsOp =
223230
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
231+
LLVM_DEBUG(llvm::dbgs()
232+
<< "Query value bounds for: " << value
233+
<< " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
224234
if (valueBoundsOp) {
225235
if (dim == kIndexValue) {
226236
valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -229,6 +239,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
229239
}
230240
continue;
231241
}
242+
LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
232243

233244
// If the op does not implement `ValueBoundsOpInterface`, check if it
234245
// implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -278,8 +289,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
278289
bool closedUB) {
279290
#ifndef NDEBUG
280291
assertValidValueDim(value, dim);
281-
assert(!stopCondition(value, dim) &&
282-
"stop condition should not be satisfied for starting point");
283292
#endif // NDEBUG
284293

285294
int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -289,9 +298,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
289298
// Process the backward slice of `value` (i.e., reverse use-def chain) until
290299
// `stopCondition` is met.
291300
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
292-
ValueBoundsConstraintSet cstr(value.getContext());
301+
ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
302+
assert(!stopCondition(value, dim, cstr) &&
303+
"stop condition should not be satisfied for starting point");
293304
int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
294-
cstr.processWorklist(stopCondition);
305+
cstr.processWorklist();
295306

296307
// Project out all variables (apart from `valueDim`) that do not match the
297308
// stop condition.
@@ -301,7 +312,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
301312
return false;
302313
auto maybeDim =
303314
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
304-
return !stopCondition(p.first, maybeDim);
315+
return !stopCondition(p.first, maybeDim, cstr);
305316
});
306317

307318
// Compute lower and upper bounds for `valueDim`.
@@ -407,7 +418,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
407418
bool closedUB) {
408419
return computeBound(
409420
resultMap, mapOperands, type, value, dim,
410-
[&](Value v, std::optional<int64_t> d) {
421+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
411422
return llvm::is_contained(dependencies, std::make_pair(v, d));
412423
},
413424
closedUB);
@@ -443,7 +454,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
443454
// Reify bounds in terms of any independent values.
444455
return computeBound(
445456
resultMap, mapOperands, type, value, dim,
446-
[&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
457+
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
458+
return isIndependent(v);
459+
},
447460
closedUB);
448461
}
449462

@@ -476,43 +489,42 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
476489
presburger::BoundType type, AffineMap map, ValueDimList operands,
477490
StopConditionFn stopCondition, bool closedUB) {
478491
assert(map.getNumResults() == 1 && "expected affine map with one result");
479-
ValueBoundsConstraintSet cstr(map.getContext());
480492

481-
int64_t pos = 0;
482-
if (stopCondition) {
483-
cstr.populateConstraintsSet(map, operands, stopCondition, &pos);
484-
} else {
485-
// No stop condition specified: Keep adding constraints until a bound could
486-
// be computed.
487-
cstr.populateConstraintsSet(
488-
map, operands,
489-
[&](Value v, std::optional<int64_t> dim) {
490-
return cstr.cstr.getConstantBound64(type, pos).has_value();
491-
},
492-
&pos);
493-
}
493+
// Default stop condition if none was specified: Keep adding constraints until
494+
// a bound could be computed.
495+
int64_t pos;
496+
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
497+
ValueBoundsConstraintSet &cstr) {
498+
return cstr.cstr.getConstantBound64(type, pos).has_value();
499+
};
500+
501+
ValueBoundsConstraintSet cstr(
502+
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
503+
cstr.populateConstraintsSet(map, operands, &pos);
504+
494505
// Compute constant bound for `valueDim`.
495506
int64_t ubAdjustment = closedUB ? 0 : 1;
496507
if (auto bound = cstr.cstr.getConstantBound64(type, pos))
497508
return type == BoundType::UB ? *bound + ubAdjustment : *bound;
498509
return failure();
499510
}
500511

501-
int64_t ValueBoundsConstraintSet::populateConstraintsSet(
502-
Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
512+
int64_t
513+
ValueBoundsConstraintSet::populateConstraintsSet(Value value,
514+
std::optional<int64_t> dim) {
503515
#ifndef NDEBUG
504516
assertValidValueDim(value, dim);
505517
#endif // NDEBUG
506518

507519
AffineMap map =
508520
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
509521
Builder(value.getContext()).getAffineDimExpr(0));
510-
return populateConstraintsSet(map, {{value, dim}}, stopCondition);
522+
return populateConstraintsSet(map, {{value, dim}});
511523
}
512524

513-
int64_t ValueBoundsConstraintSet::populateConstraintsSet(
514-
AffineMap map, ValueDimList operands, StopConditionFn stopCondition,
515-
int64_t *posOut) {
525+
int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
526+
ValueDimList operands,
527+
int64_t *posOut) {
516528
assert(map.getNumResults() == 1 && "expected affine map with one result");
517529
int64_t pos = insert(/*isSymbol=*/false);
518530
if (posOut)
@@ -533,13 +545,7 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
533545

534546
// Process the backward slice of `operands` (i.e., reverse use-def chain)
535547
// until `stopCondition` is met.
536-
if (stopCondition) {
537-
processWorklist(stopCondition);
538-
} else {
539-
// No stop condition specified: Keep adding constraints until the worklist
540-
// is empty.
541-
processWorklist([](Value v, std::optional<int64_t> dim) { return false; });
542-
}
548+
processWorklist();
543549

544550
return pos;
545551
}

mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
117117

118118
// Prepare stop condition. By default, reify in terms of the op's
119119
// operands. No stop condition is used when a constant was requested.
120-
std::function<bool(Value, std::optional<int64_t>)> stopCondition =
121-
[&](Value v, std::optional<int64_t> d) {
120+
std::function<bool(Value, std::optional<int64_t>,
121+
ValueBoundsConstraintSet & cstr)>
122+
stopCondition = [&](Value v, std::optional<int64_t> d,
123+
ValueBoundsConstraintSet &cstr) {
122124
// Reify in terms of SSA values that are different from `value`.
123125
return v != value;
124126
};
125127
if (reifyToFuncArgs) {
126128
// Reify in terms of function block arguments.
127-
stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
129+
stopCondition = [](Value v, std::optional<int64_t> d,
130+
ValueBoundsConstraintSet &cstr) {
128131
auto bbArg = dyn_cast<BlockArgument>(v);
129132
if (!bbArg)
130133
return false;

0 commit comments

Comments
 (0)