Skip to content

Commit 0c61b24

Browse files
authored
[mlir] add a fluent API to GreedyRewriterConfig (#137122)
This is similar to other configuration objects used across MLIR. Rename some fields to better reflect that they are no longer booleans. Reland 04d2611 / #132253.
1 parent 15bb1db commit 0c61b24

30 files changed

+224
-168
lines changed

flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ class LowerRepackArraysPass
357357
patterns.insert<PackArrayConversion>(context);
358358
patterns.insert<UnpackArrayConversion>(context);
359359
mlir::GreedyRewriteConfig config;
360-
config.enableRegionSimplification =
361-
mlir::GreedySimplifyRegionLevel::Disabled;
360+
config.setRegionSimplificationLevel(
361+
mlir::GreedySimplifyRegionLevel::Disabled);
362362
(void)applyPatternsGreedily(module, std::move(patterns), config);
363363
}
364364

flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ class InlineElementalsPass
119119

120120
mlir::GreedyRewriteConfig config;
121121
// Prevent the pattern driver from merging blocks.
122-
config.enableRegionSimplification =
123-
mlir::GreedySimplifyRegionLevel::Disabled;
122+
config.setRegionSimplificationLevel(
123+
mlir::GreedySimplifyRegionLevel::Disabled);
124124

125125
mlir::RewritePatternSet patterns(context);
126126
patterns.insert<InlineElementalConversion>(context);

flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ class InlineHLFIRAssignPass
135135

136136
mlir::GreedyRewriteConfig config;
137137
// Prevent the pattern driver from merging blocks.
138-
config.enableRegionSimplification =
139-
mlir::GreedySimplifyRegionLevel::Disabled;
138+
config.setRegionSimplificationLevel(
139+
mlir::GreedySimplifyRegionLevel::Disabled);
140140

141141
mlir::RewritePatternSet patterns(context);
142142
patterns.insert<InlineHLFIRAssignConversion>(context);

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ class LowerHLFIRIntrinsics
557557
// Pattern rewriting only requires that the resulting IR is still valid
558558
mlir::GreedyRewriteConfig config;
559559
// Prevent the pattern driver from merging blocks
560-
config.enableRegionSimplification =
561-
mlir::GreedySimplifyRegionLevel::Disabled;
560+
config.setRegionSimplificationLevel(
561+
mlir::GreedySimplifyRegionLevel::Disabled);
562562

563563
if (mlir::failed(
564564
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,8 @@ class OptimizedBufferizationPass
875875

876876
mlir::GreedyRewriteConfig config;
877877
// Prevent the pattern driver from merging blocks
878-
config.enableRegionSimplification =
879-
mlir::GreedySimplifyRegionLevel::Disabled;
878+
config.setRegionSimplificationLevel(
879+
mlir::GreedySimplifyRegionLevel::Disabled);
880880

881881
mlir::RewritePatternSet patterns(context);
882882
// TODO: right now the patterns are non-conflicting,

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -2132,8 +2132,8 @@ class SimplifyHLFIRIntrinsics
21322132

21332133
mlir::GreedyRewriteConfig config;
21342134
// Prevent the pattern driver from merging blocks
2135-
config.enableRegionSimplification =
2136-
mlir::GreedySimplifyRegionLevel::Disabled;
2135+
config.setRegionSimplificationLevel(
2136+
mlir::GreedySimplifyRegionLevel::Disabled);
21372137

21382138
mlir::RewritePatternSet patterns(context);
21392139
patterns.insert<TransposeAsElementalConversion>(context);

flang/lib/Optimizer/Passes/Pipelines.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ void addNestedPassToAllTopLevelOperationsConditionally(
3535

3636
void addCanonicalizerPassWithoutRegionSimplification(mlir::OpPassManager &pm) {
3737
mlir::GreedyRewriteConfig config;
38-
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
38+
config.setRegionSimplificationLevel(
39+
mlir::GreedySimplifyRegionLevel::Disabled);
3940
pm.addPass(mlir::createCanonicalizerPass(config));
4041
}
4142

@@ -163,7 +164,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
163164

164165
// simplify the IR
165166
mlir::GreedyRewriteConfig config;
166-
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
167+
config.setRegionSimplificationLevel(
168+
mlir::GreedySimplifyRegionLevel::Disabled);
167169
pm.addPass(mlir::createCSEPass());
168170
fir::addAVC(pm, pc.OptLevel);
169171
addNestedPassToAllTopLevelOperations<PassConstructor>(

flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ class AssumedRankOpConversion
152152
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
153153
patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
154154
mlir::GreedyRewriteConfig config;
155-
config.enableRegionSimplification =
156-
mlir::GreedySimplifyRegionLevel::Disabled;
155+
config.setRegionSimplificationLevel(
156+
mlir::GreedySimplifyRegionLevel::Disabled);
157157
(void)applyPatternsGreedily(mod, std::move(patterns), config);
158158
}
159159
};

flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ class ConstantArgumentGlobalisationOpt
168168
auto *context = &getContext();
169169
mlir::RewritePatternSet patterns(context);
170170
mlir::GreedyRewriteConfig config;
171-
config.enableRegionSimplification =
172-
mlir::GreedySimplifyRegionLevel::Disabled;
173-
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
171+
config.setRegionSimplificationLevel(
172+
mlir::GreedySimplifyRegionLevel::Disabled);
173+
config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);
174174

175175
patterns.insert<CallOpRewriter>(context, *di);
176176
if (mlir::failed(

flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ void SimplifyFIROperationsPass::runOnOperation() {
205205
fir::populateSimplifyFIROperationsPatterns(patterns,
206206
preferInlineImplementation);
207207
mlir::GreedyRewriteConfig config;
208-
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
208+
config.setRegionSimplificationLevel(
209+
mlir::GreedySimplifyRegionLevel::Disabled);
209210

210211
if (mlir::failed(
211212
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {

flang/lib/Optimizer/Transforms/StackArrays.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,8 @@ void StackArraysPass::runOnOperation() {
806806
mlir::RewritePatternSet patterns(&context);
807807
mlir::GreedyRewriteConfig config;
808808
// prevent the pattern driver form merging blocks
809-
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
809+
config.setRegionSimplificationLevel(
810+
mlir::GreedySimplifyRegionLevel::Disabled);
810811

811812
patterns.insert<AllocMemConversion>(&context, *candidateOps);
812813
if (mlir::failed(mlir::applyOpPatternsGreedily(

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

+63-13
Original file line numberDiff line numberDiff line change
@@ -49,33 +49,55 @@ class GreedyRewriteConfig {
4949
/// larger patterns when given an ambiguous pattern set.
5050
///
5151
/// Note: Only applicable when simplifying entire regions.
52-
bool useTopDownTraversal = false;
52+
bool getUseTopDownTraversal() const { return useTopDownTraversal; }
53+
GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
54+
useTopDownTraversal = use;
55+
return *this;
56+
}
5357

5458
/// Perform control flow optimizations to the region tree after applying all
5559
/// patterns.
5660
///
5761
/// Note: Only applicable when simplifying entire regions.
58-
GreedySimplifyRegionLevel enableRegionSimplification =
59-
GreedySimplifyRegionLevel::Aggressive;
62+
GreedySimplifyRegionLevel getRegionSimplificationLevel() const {
63+
return regionSimplificationLevel;
64+
}
65+
GreedyRewriteConfig &
66+
setRegionSimplificationLevel(GreedySimplifyRegionLevel level) {
67+
regionSimplificationLevel = level;
68+
return *this;
69+
}
6070

6171
/// This specifies the maximum number of times the rewriter will iterate
6272
/// between applying patterns and simplifying regions. Use `kNoLimit` to
6373
/// disable this iteration limit.
6474
///
6575
/// Note: Only applicable when simplifying entire regions.
66-
int64_t maxIterations = 10;
76+
int64_t getMaxIterations() const { return maxIterations; }
77+
GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
78+
maxIterations = iterations;
79+
return *this;
80+
}
6781

6882
/// This specifies the maximum number of rewrites within an iteration. Use
6983
/// `kNoLimit` to disable this limit.
70-
int64_t maxNumRewrites = kNoLimit;
84+
int64_t getMaxNumRewrites() const { return maxNumRewrites; }
85+
GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
86+
maxNumRewrites = limit;
87+
return *this;
88+
}
7189

7290
static constexpr int64_t kNoLimit = -1;
7391

7492
/// Only ops within the scope are added to the worklist. If no scope is
7593
/// specified, the closest enclosing region around the initial list of ops
7694
/// (or the specified region, depending on which greedy rewrite entry point
7795
/// is used) is used as a scope.
78-
Region *scope = nullptr;
96+
Region *getScope() const { return scope; }
97+
GreedyRewriteConfig &setScope(Region *scope) {
98+
this->scope = scope;
99+
return *this;
100+
}
79101

80102
/// Strict mode can restrict the ops that are added to the worklist during
81103
/// the rewrite.
@@ -87,16 +109,44 @@ class GreedyRewriteConfig {
87109
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
88110
/// were on the worklist at the very beginning) enqueued. All other ops are
89111
/// excluded.
90-
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
112+
GreedyRewriteStrictness getStrictness() const { return strictness; }
113+
GreedyRewriteConfig &setStrictness(GreedyRewriteStrictness mode) {
114+
strictness = mode;
115+
return *this;
116+
}
91117

92118
/// An optional listener that should be notified about IR modifications.
93-
RewriterBase::Listener *listener = nullptr;
119+
RewriterBase::Listener *getListener() const { return listener; }
120+
GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
121+
this->listener = listener;
122+
return *this;
123+
}
94124

95125
/// Whether this should fold while greedily rewriting.
96-
bool fold = true;
126+
bool isFoldingEnabled() const { return fold; }
127+
GreedyRewriteConfig &enableFolding(bool enable = true) {
128+
fold = enable;
129+
return *this;
130+
}
97131

98132
/// If set to "true", constants are CSE'd (even across multiple regions that
99133
/// are in a parent-ancestor relationship).
134+
bool isConstantCSEEnabled() const { return cseConstants; }
135+
GreedyRewriteConfig &enableConstantCSE(bool enable = true) {
136+
cseConstants = enable;
137+
return *this;
138+
}
139+
140+
private:
141+
Region *scope = nullptr;
142+
bool useTopDownTraversal = false;
143+
GreedySimplifyRegionLevel regionSimplificationLevel =
144+
GreedySimplifyRegionLevel::Aggressive;
145+
int64_t maxIterations = 10;
146+
int64_t maxNumRewrites = kNoLimit;
147+
GreedyRewriteStrictness strictness = GreedyRewriteStrictness::AnyOp;
148+
RewriterBase::Listener *listener = nullptr;
149+
bool fold = true;
100150
bool cseConstants = true;
101151
};
102152

@@ -128,14 +178,14 @@ applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
128178
GreedyRewriteConfig config = GreedyRewriteConfig(),
129179
bool *changed = nullptr);
130180
/// Same as `applyPatternsAndGreedily` above with folding.
131-
/// FIXME: Remove this once transition to above is complieted.
181+
/// FIXME: Remove this once transition to above is completed.
132182
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
133183
inline LogicalResult
134184
applyPatternsAndFoldGreedily(Region &region,
135185
const FrozenRewritePatternSet &patterns,
136186
GreedyRewriteConfig config = GreedyRewriteConfig(),
137187
bool *changed = nullptr) {
138-
config.fold = true;
188+
config.enableFolding();
139189
return applyPatternsGreedily(region, patterns, config, changed);
140190
}
141191

@@ -187,7 +237,7 @@ applyPatternsAndFoldGreedily(Operation *op,
187237
const FrozenRewritePatternSet &patterns,
188238
GreedyRewriteConfig config = GreedyRewriteConfig(),
189239
bool *changed = nullptr) {
190-
config.fold = true;
240+
config.enableFolding();
191241
return applyPatternsGreedily(op, patterns, config, changed);
192242
}
193243

@@ -233,7 +283,7 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
233283
const FrozenRewritePatternSet &patterns,
234284
GreedyRewriteConfig config = GreedyRewriteConfig(),
235285
bool *changed = nullptr, bool *allErased = nullptr) {
236-
config.fold = true;
286+
config.enableFolding();
237287
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
238288
}
239289

mlir/include/mlir/Transforms/Passes.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def Canonicalizer : Pass<"canonicalize"> {
3333
Option<"topDownProcessingEnabled", "top-down", "bool",
3434
/*default=*/"true",
3535
"Seed the worklist in general top-down order">,
36-
Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
36+
Option<"regionSimplifyLevel", "region-simplify", "mlir::GreedySimplifyRegionLevel",
3737
/*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
3838
"Perform control flow optimizations to the region tree",
3939
[{::llvm::cl::values(

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
127127
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128128
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129129
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130-
GreedyRewriteConfig config;
131-
config.listener =
132-
static_cast<RewriterBase::Listener *>(rewriter.getListener());
133-
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
134130
// Apply the simplification pattern to a fixpoint.
135-
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
131+
if (failed(applyOpPatternsGreedily(
132+
targets, frozenPatterns,
133+
GreedyRewriteConfig()
134+
.setListener(
135+
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
136+
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps)))) {
136137
auto diag = emitDefiniteFailure()
137138
<< "affine.min/max simplification did not converge";
138139
return diag;

mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
237237
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
238238
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
239239
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
240-
GreedyRewriteConfig config;
241-
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
242-
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
240+
(void)applyOpPatternsGreedily(
241+
copyOps, frozenPatterns,
242+
GreedyRewriteConfig().setStrictness(
243+
GreedyRewriteStrictness::ExistingAndNewOps));
243244
}

mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
109109
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
110110
opsToSimplify.push_back(op);
111111
});
112-
GreedyRewriteConfig config;
113-
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
114-
(void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
112+
(void)applyOpPatternsGreedily(
113+
opsToSimplify, frozenPatterns,
114+
GreedyRewriteConfig().setStrictness(
115+
GreedyRewriteStrictness::ExistingAndNewOps));
115116
}

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
315315
// Simplify/canonicalize the affine.for.
316316
RewritePatternSet patterns(res.getContext());
317317
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
318-
GreedyRewriteConfig config;
319-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
320318
bool erased;
321-
(void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
322-
config, /*changed=*/nullptr, &erased);
319+
(void)applyOpPatternsGreedily(
320+
res.getOperation(), std::move(patterns),
321+
GreedyRewriteConfig().setStrictness(
322+
GreedyRewriteStrictness::ExistingAndNewOps),
323+
/*changed=*/nullptr, &erased);
323324
if (!erased && !prologue)
324325
prologue = res;
325326
if (!erased)

mlir/lib/Dialect/Affine/Utils/Utils.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -426,11 +426,11 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
426426
RewritePatternSet patterns(ifOp.getContext());
427427
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
428428
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
429-
GreedyRewriteConfig config;
430-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
431429
bool erased;
432-
(void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
433-
/*changed=*/nullptr, &erased);
430+
(void)applyOpPatternsGreedily(
431+
ifOp.getOperation(), frozenPatterns,
432+
GreedyRewriteConfig().setStrictness(GreedyRewriteStrictness::ExistingOps),
433+
/*changed=*/nullptr, &erased);
434434
if (erased) {
435435
if (folded)
436436
*folded = true;

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

+7-9
Original file line numberDiff line numberDiff line change
@@ -494,10 +494,9 @@ struct IntRangeOptimizationsPass final
494494
RewritePatternSet patterns(ctx);
495495
populateIntRangeOptimizationsPatterns(patterns, solver);
496496

497-
GreedyRewriteConfig config;
498-
config.listener = &listener;
499-
500-
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
497+
if (failed(applyPatternsGreedily(
498+
op, std::move(patterns),
499+
GreedyRewriteConfig().setListener(&listener))))
501500
signalPassFailure();
502501
}
503502
};
@@ -520,13 +519,12 @@ struct IntRangeNarrowingPass final
520519
RewritePatternSet patterns(ctx);
521520
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
522521

523-
GreedyRewriteConfig config;
524522
// We specifically need bottom-up traversal as cmpi pattern needs range
525523
// data, attached to its original argument values.
526-
config.useTopDownTraversal = false;
527-
config.listener = &listener;
528-
529-
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
524+
if (failed(applyPatternsGreedily(
525+
op, std::move(patterns),
526+
GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
527+
&listener))))
530528
signalPassFailure();
531529
}
532530
};

0 commit comments

Comments
 (0)