Skip to content

Commit 35f77dc

Browse files
committed
[mlir] add a fluent API to GreedyRewriterConfig
This is similar to other configuration objects used across MLIR.
1 parent b0b97e3 commit 35f77dc

File tree

11 files changed

+83
-45
lines changed

11 files changed

+83
-45
lines changed

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,41 @@ class GreedyRewriteConfig {
5050
///
5151
/// Note: Only applicable when simplifying entire regions.
5252
bool useTopDownTraversal = false;
53+
GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
54+
useTopDownTraversal = use;
55+
return *this;
56+
}
5357

5458
/// Perform control flow optimizations to the region tree after applying all
5559
/// patterns.
5660
///
5761
/// Note: Only applicable when simplifying entire regions.
5862
GreedySimplifyRegionLevel enableRegionSimplification =
5963
GreedySimplifyRegionLevel::Aggressive;
64+
GreedyRewriteConfig &
65+
setEnableRegionSimplification(GreedySimplifyRegionLevel level) {
66+
enableRegionSimplification = level;
67+
return *this;
68+
}
6069

6170
/// This specifies the maximum number of times the rewriter will iterate
6271
/// between applying patterns and simplifying regions. Use `kNoLimit` to
6372
/// disable this iteration limit.
6473
///
6574
/// Note: Only applicable when simplifying entire regions.
6675
int64_t maxIterations = 10;
76+
GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
77+
maxIterations = iterations;
78+
return *this;
79+
}
6780

6881
/// This specifies the maximum number of rewrites within an iteration. Use
6982
/// `kNoLimit` to disable this limit.
7083
int64_t maxNumRewrites = kNoLimit;
84+
GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
85+
maxNumRewrites = limit;
86+
return *this;
87+
}
7188

7289
static constexpr int64_t kNoLimit = -1;
7390

@@ -76,6 +93,10 @@ class GreedyRewriteConfig {
7693
/// (or the specified region, depending on which greedy rewrite entry point
7794
/// is used) is used as a scope.
7895
Region *scope = nullptr;
96+
GreedyRewriteConfig &setScope(Region *scope) {
97+
this->scope = scope;
98+
return *this;
99+
}
79100

80101
/// Strict mode can restrict the ops that are added to the worklist during
81102
/// the rewrite.
@@ -88,16 +109,32 @@ class GreedyRewriteConfig {
88109
/// were on the worklist at the very beginning) enqueued. All other ops are
89110
/// excluded.
90111
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
112+
GreedyRewriteConfig &setStrictMode(GreedyRewriteStrictness mode) {
113+
strictMode = mode;
114+
return *this;
115+
}
91116

92117
/// An optional listener that should be notified about IR modifications.
93118
RewriterBase::Listener *listener = nullptr;
119+
GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
120+
this->listener = listener;
121+
return *this;
122+
}
94123

95124
/// Whether this should fold while greedily rewriting.
96125
bool fold = true;
126+
GreedyRewriteConfig &setFold(bool enable = true) {
127+
fold = enable;
128+
return *this;
129+
}
97130

98131
/// If set to "true", constants are CSE'd (even across multiple regions that
99132
/// are in a parent-ancestor relationship).
100133
bool cseConstants = true;
134+
GreedyRewriteConfig &setCSEConstants(bool enable = true) {
135+
cseConstants = enable;
136+
return *this;
137+
}
101138
};
102139

103140
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
127127
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128128
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129129
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130-
GreedyRewriteConfig config;
131-
config.listener =
132-
static_cast<RewriterBase::Listener *>(rewriter.getListener());
133-
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
134130
// Apply the simplification pattern to a fixpoint.
135-
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
131+
if (failed(applyOpPatternsGreedily(
132+
targets, frozenPatterns,
133+
GreedyRewriteConfig()
134+
.setListener(
135+
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
136+
.setStrictMode(GreedyRewriteStrictness::ExistingAndNewOps)))) {
136137
auto diag = emitDefiniteFailure()
137138
<< "affine.min/max simplification did not converge";
138139
return diag;

mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
237237
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
238238
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
239239
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
240-
GreedyRewriteConfig config;
241-
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
242-
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
240+
(void)applyOpPatternsGreedily(
241+
copyOps, frozenPatterns,
242+
GreedyRewriteConfig().setStrictMode(
243+
GreedyRewriteStrictness::ExistingAndNewOps));
243244
}

mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
109109
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
110110
opsToSimplify.push_back(op);
111111
});
112-
GreedyRewriteConfig config;
113-
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
114-
(void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
112+
(void)applyOpPatternsGreedily(
113+
opsToSimplify, frozenPatterns,
114+
GreedyRewriteConfig().setStrictMode(
115+
GreedyRewriteStrictness::ExistingAndNewOps));
115116
}

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
315315
// Simplify/canonicalize the affine.for.
316316
RewritePatternSet patterns(res.getContext());
317317
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
318-
GreedyRewriteConfig config;
319-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
320318
bool erased;
321-
(void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
322-
config, /*changed=*/nullptr, &erased);
319+
(void)applyOpPatternsGreedily(
320+
res.getOperation(), std::move(patterns),
321+
GreedyRewriteConfig().setStrictMode(
322+
GreedyRewriteStrictness::ExistingAndNewOps),
323+
/*changed=*/nullptr, &erased);
323324
if (!erased && !prologue)
324325
prologue = res;
325326
if (!erased)

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -494,10 +494,9 @@ struct IntRangeOptimizationsPass final
494494
RewritePatternSet patterns(ctx);
495495
populateIntRangeOptimizationsPatterns(patterns, solver);
496496

497-
GreedyRewriteConfig config;
498-
config.listener = &listener;
499-
500-
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
497+
if (failed(applyPatternsGreedily(
498+
op, std::move(patterns),
499+
GreedyRewriteConfig().setListener(&listener))))
501500
signalPassFailure();
502501
}
503502
};
@@ -520,13 +519,12 @@ struct IntRangeNarrowingPass final
520519
RewritePatternSet patterns(ctx);
521520
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
522521

523-
GreedyRewriteConfig config;
524522
// We specifically need bottom-up traversal as cmpi pattern needs range
525523
// data, attached to its original argument values.
526-
config.useTopDownTraversal = false;
527-
config.listener = &listener;
528-
529-
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
524+
if (failed(applyPatternsGreedily(
525+
op, std::move(patterns),
526+
GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
527+
&listener))))
530528
signalPassFailure();
531529
}
532530
};

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,15 +463,15 @@ struct BufferDeallocationSimplificationPass
463463
SplitDeallocWhenNotAliasingAnyOther,
464464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465465
analysis);
466+
467+
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
466468
// We don't want that the block structure changes invalidating the
467-
// `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
469+
// `BufferOriginAnalysis` so we apply the rewrites with `Normal` level of
468470
// region simplification
469-
GreedyRewriteConfig config;
470-
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
471-
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
472-
473-
if (failed(
474-
applyPatternsGreedily(getOperation(), std::move(patterns), config)))
471+
if (failed(applyPatternsGreedily(
472+
getOperation(), std::move(patterns),
473+
GreedyRewriteConfig().setEnableRegionSimplification(
474+
GreedySimplifyRegionLevel::Normal))))
475475
signalPassFailure();
476476
}
477477
};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3587,9 +3587,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
35873587
vector::populateVectorStepLoweringPatterns(patterns);
35883588

35893589
TrackingListener listener(state, *this);
3590-
GreedyRewriteConfig config;
3591-
config.listener = &listener;
3592-
if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
3590+
if (failed(
3591+
applyPatternsGreedily(target, std::move(patterns),
3592+
GreedyRewriteConfig().setListener(&listener))))
35933593
return emitDefaultDefiniteFailure(target);
35943594

35953595
results.push_back(target);

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2327,10 +2327,9 @@ struct LinalgElementwiseOpFusionPass
23272327
// Add constant folding patterns.
23282328
populateConstantFoldLinalgOperations(patterns, defaultControlFn);
23292329

2330-
// Use TopDownTraversal for compile time reasons
2331-
GreedyRewriteConfig grc;
2332-
grc.useTopDownTraversal = true;
2333-
(void)applyPatternsGreedily(op, std::move(patterns), grc);
2330+
// Use TopDownTraversal for compile time reasons.
2331+
(void)applyPatternsGreedily(op, std::move(patterns),
2332+
GreedyRewriteConfig().setUseTopDownTraversal());
23342333
}
23352334
};
23362335

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,10 +1438,10 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
14381438
if (!patterns)
14391439
return success();
14401440

1441-
GreedyRewriteConfig config;
1442-
config.listener = this;
1443-
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1444-
return applyOpPatternsGreedily(ops, patterns.value(), config);
1441+
return applyOpPatternsGreedily(
1442+
ops, patterns.value(),
1443+
GreedyRewriteConfig().setListener(this).setStrictMode(
1444+
GreedyRewriteStrictness::ExistingAndNewOps));
14451445
}
14461446

14471447
void SliceTrackingListener::notifyOperationInserted(

0 commit comments

Comments
 (0)