Skip to content

[mlir] add a fluent API to GreedyRewriterConfig #137122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ class LowerRepackArraysPass
patterns.insert<PackArrayConversion>(context);
patterns.insert<UnpackArrayConversion>(context);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
(void)applyPatternsGreedily(module, std::move(patterns), config);
}

Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ class InlineElementalsPass

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ class InlineHLFIRAssignPass

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

mlir::RewritePatternSet patterns(context);
patterns.insert<InlineHLFIRAssignConversion>(context);
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,8 @@ class LowerHLFIRIntrinsics
// Pattern rewriting only requires that the resulting IR is still valid
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,8 @@ class OptimizedBufferizationPass

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

mlir::RewritePatternSet patterns(context);
// TODO: right now the patterns are non-conflicting,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2132,8 +2132,8 @@ class SimplifyHLFIRIntrinsics

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
Expand Down
6 changes: 4 additions & 2 deletions flang/lib/Optimizer/Passes/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ void addNestedPassToAllTopLevelOperationsConditionally(

void addCanonicalizerPassWithoutRegionSimplification(mlir::OpPassManager &pm) {
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
pm.addPass(mlir::createCanonicalizerPass(config));
}

Expand Down Expand Up @@ -163,7 +164,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,

// simplify the IR
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
pm.addPass(mlir::createCSEPass());
fir::addAVC(pm, pc.OptLevel);
addNestedPassToAllTopLevelOperations<PassConstructor>(
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ class AssumedRankOpConversion
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
(void)applyPatternsGreedily(mod, std::move(patterns), config);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ class ConstantArgumentGlobalisationOpt
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);
config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps);

patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ void SimplifyFIROperationsPass::runOnOperation() {
fir::populateSimplifyFIROperationsPatterns(patterns,
preferInlineImplementation);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,8 @@ void StackArraysPass::runOnOperation() {
mlir::RewritePatternSet patterns(&context);
mlir::GreedyRewriteConfig config;
// prevent the pattern driver form merging blocks
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.setRegionSimplificationLevel(
mlir::GreedySimplifyRegionLevel::Disabled);

patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsGreedily(
Expand Down
76 changes: 63 additions & 13 deletions mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,33 +49,55 @@ class GreedyRewriteConfig {
/// larger patterns when given an ambiguous pattern set.
///
/// Note: Only applicable when simplifying entire regions.
bool useTopDownTraversal = false;
bool getUseTopDownTraversal() const { return useTopDownTraversal; }
GreedyRewriteConfig &setUseTopDownTraversal(bool use = true) {
useTopDownTraversal = use;
return *this;
}

/// Perform control flow optimizations to the region tree after applying all
/// patterns.
///
/// Note: Only applicable when simplifying entire regions.
GreedySimplifyRegionLevel enableRegionSimplification =
GreedySimplifyRegionLevel::Aggressive;
GreedySimplifyRegionLevel getRegionSimplificationLevel() const {
return regionSimplificationLevel;
}
GreedyRewriteConfig &
setRegionSimplificationLevel(GreedySimplifyRegionLevel level) {
regionSimplificationLevel = level;
return *this;
}

/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
///
/// Note: Only applicable when simplifying entire regions.
int64_t maxIterations = 10;
int64_t getMaxIterations() const { return maxIterations; }
GreedyRewriteConfig &setMaxIterations(int64_t iterations) {
maxIterations = iterations;
return *this;
}

/// This specifies the maximum number of rewrites within an iteration. Use
/// `kNoLimit` to disable this limit.
int64_t maxNumRewrites = kNoLimit;
int64_t getMaxNumRewrites() const { return maxNumRewrites; }
GreedyRewriteConfig &setMaxNumRewrites(int64_t limit) {
maxNumRewrites = limit;
return *this;
}

static constexpr int64_t kNoLimit = -1;

/// Only ops within the scope are added to the worklist. If no scope is
/// specified, the closest enclosing region around the initial list of ops
/// (or the specified region, depending on which greedy rewrite entry point
/// is used) is used as a scope.
Region *scope = nullptr;
Region *getScope() const { return scope; }
GreedyRewriteConfig &setScope(Region *scope) {
this->scope = scope;
return *this;
}

/// Strict mode can restrict the ops that are added to the worklist during
/// the rewrite.
Expand All @@ -87,16 +109,44 @@ class GreedyRewriteConfig {
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
/// were on the worklist at the very beginning) enqueued. All other ops are
/// excluded.
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
GreedyRewriteStrictness getStrictness() const { return strictness; }
GreedyRewriteConfig &setStrictness(GreedyRewriteStrictness mode) {
strictness = mode;
return *this;
}

/// An optional listener that should be notified about IR modifications.
RewriterBase::Listener *listener = nullptr;
RewriterBase::Listener *getListener() const { return listener; }
GreedyRewriteConfig &setListener(RewriterBase::Listener *listener) {
this->listener = listener;
return *this;
}

/// Whether this should fold while greedily rewriting.
bool fold = true;
bool isFoldingEnabled() const { return fold; }
GreedyRewriteConfig &enableFolding(bool enable = true) {
fold = enable;
return *this;
}

/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
bool isConstantCSEEnabled() const { return cseConstants; }
GreedyRewriteConfig &enableConstantCSE(bool enable = true) {
cseConstants = enable;
return *this;
}

private:
Region *scope = nullptr;
bool useTopDownTraversal = false;
GreedySimplifyRegionLevel regionSimplificationLevel =
GreedySimplifyRegionLevel::Aggressive;
int64_t maxIterations = 10;
int64_t maxNumRewrites = kNoLimit;
GreedyRewriteStrictness strictness = GreedyRewriteStrictness::AnyOp;
RewriterBase::Listener *listener = nullptr;
bool fold = true;
bool cseConstants = true;
};

Expand Down Expand Up @@ -128,14 +178,14 @@ applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Same as `applyPatternsAndGreedily` above with folding.
/// FIXME: Remove this once transition to above is complieted.
/// FIXME: Remove this once transition to above is completed.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
config.fold = true;
config.enableFolding();
return applyPatternsGreedily(region, patterns, config, changed);
}

Expand Down Expand Up @@ -187,7 +237,7 @@ applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
config.fold = true;
config.enableFolding();
return applyPatternsGreedily(op, patterns, config, changed);
}

Expand Down Expand Up @@ -233,7 +283,7 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr) {
config.fold = true;
config.enableFolding();
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
Option<"regionSimplifyLevel", "region-simplify", "mlir::GreedySimplifyRegionLevel",
/*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
"Perform control flow optimizations to the region tree",
[{::llvm::cl::values(
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,13 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.listener =
static_cast<RewriterBase::Listener *>(rewriter.getListener());
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
if (failed(applyOpPatternsGreedily(
targets, frozenPatterns,
GreedyRewriteConfig()
.setListener(
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps)))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
(void)applyOpPatternsGreedily(
copyOps, frozenPatterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingAndNewOps));
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
(void)applyOpPatternsGreedily(
opsToSimplify, frozenPatterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingAndNewOps));
}
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,12 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
// Simplify/canonicalize the affine.for.
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
config, /*changed=*/nullptr, &erased);
(void)applyOpPatternsGreedily(
res.getOperation(), std::move(patterns),
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingAndNewOps),
/*changed=*/nullptr, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,11 +426,11 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
RewritePatternSet patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
/*changed=*/nullptr, &erased);
(void)applyOpPatternsGreedily(
ifOp.getOperation(), frozenPatterns,
GreedyRewriteConfig().setStrictness(GreedyRewriteStrictness::ExistingOps),
/*changed=*/nullptr, &erased);
if (erased) {
if (folded)
*folded = true;
Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,9 @@ struct IntRangeOptimizationsPass final
RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);

GreedyRewriteConfig config;
config.listener = &listener;

if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
if (failed(applyPatternsGreedily(
op, std::move(patterns),
GreedyRewriteConfig().setListener(&listener))))
signalPassFailure();
}
};
Expand All @@ -520,13 +519,12 @@ struct IntRangeNarrowingPass final
RewritePatternSet patterns(ctx);
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);

GreedyRewriteConfig config;
// We specifically need bottom-up traversal as cmpi pattern needs range
// data, attached to its original argument values.
config.useTopDownTraversal = false;
config.listener = &listener;

if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
if (failed(applyPatternsGreedily(
op, std::move(patterns),
GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
&listener))))
signalPassFailure();
}
};
Expand Down
Loading
Loading