Skip to content

Commit

Permalink
[mlir] Enable decoupling two kinds of greedy behavior. (#104649)
Browse files Browse the repository at this point in the history
The greedy rewriter is used in many different flows and it has a lot of
convenience (work list management, debugging actions, tracing, etc). But
it combines two kinds of greedy behavior 1) how ops are matched, 2)
folding wherever it can.

These are independent forms of greedy and leads to inefficiency. E.g.,
cases where one need to create different phases in lowering and is
required to applying patterns in specific order split across different
passes. Using the driver one ends up needlessly retrying folding/having
multiple rounds of folding attempts, where one final run would have
sufficed.

Of course folks can locally avoid this behavior by just building their
own, but this is also a common requested feature that folks keep on
working around locally in suboptimal ways.

For downstream users, there should be no behavioral change. Updating
from the deprecated should just be a find and replace (e.g., `find ./
-type f -exec sed -i
's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety)
as the API arguments hasn't changed between the two.
  • Loading branch information
jpienaar authored Dec 20, 2024
1 parent 412e1af commit 09dfc57
Show file tree
Hide file tree
Showing 110 changed files with 313 additions and 246 deletions.
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class InlineElementalsPass
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR elemental inlining");
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,8 @@ class LowerHLFIRIntrinsics
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
module, std::move(patterns), config))) {
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
mlir::emitError(mlir::UnknownLoc::get(context),
"failure in HLFIR intrinsic lowering");
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,7 @@ class OptimizedBufferizationPass
// patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
// patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR optimized bufferization");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class SimplifyHLFIRIntrinsics
patterns.insert<SumAsElementalConversion>(context);
patterns.insert<CShiftAsElementalConversion>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR intrinsic simplification");
Expand Down
3 changes: 1 addition & 2 deletions flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ struct AlgebraicSimplification
void AlgebraicSimplification::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateMathAlgebraicSimplificationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}

std::unique_ptr<mlir::Pass> fir::createAlgebraicSimplificationPass() {
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class AssumedRankOpConversion
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
(void)applyPatternsGreedily(mod, std::move(patterns), config);
}
};
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ class ConstantArgumentGlobalisationOpt
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;

patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
mod, std::move(patterns), config))) {
if (mlir::failed(
mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
mlir::emitError(mod.getLoc(),
"error in constant globalisation optimization\n");
signalPassFailure();
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,8 @@ void StackArraysPass::runOnOperation() {
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
std::move(patterns), config))) {
if (mlir::failed(mlir::applyOpPatternsGreedily(
opsToConvert, std::move(patterns), config))) {
mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
signalPassFailure();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ which point the driver finishes.

This driver comes in two fashions:

* `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to
* `applyPatternsGreedily` ("region-based driver") applies patterns to
all ops in a given region or a given container op (but not the container op
itself). I.e., the worklist is initialized with all containing ops.
* `applyOpPatternsAndFold` ("op-based driver") applies patterns to the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class StandaloneSwitchBarFoo
RewritePatternSet patterns(&getContext());
patterns.add<StandaloneSwitchBarFooRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
if (failed(applyPatternsGreedily(getOperation(), patternSet)))
signalPassFailure();
}
};
Expand Down
72 changes: 55 additions & 17 deletions mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ class GreedyRewriteConfig {

/// An optional listener that should be notified about IR modifications.
RewriterBase::Listener *listener = nullptr;

/// Whether this should fold while greedily rewriting.
bool fold = true;

/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
bool cseConstants = true;
};

//===----------------------------------------------------------------------===//
Expand All @@ -104,8 +111,8 @@ class GreedyRewriteConfig {
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// A region scope can be set in the configuration parameter. By default, the
/// scope is set to the specified region. Only in-scope ops are added to the
Expand All @@ -117,10 +124,20 @@ class GreedyRewriteConfig {
///
/// Note: This method does not apply patterns to the region's parent operation.
LogicalResult
applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Same as `applyPatternsAndGreedily` above with folding.
/// FIXME: Remove this once transition to above is complieted.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
bool *changed = nullptr) {
config.fold = true;
return applyPatternsGreedily(region, patterns, config, changed);
}

/// Rewrite ops nested under the given operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
Expand All @@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region &region,
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// This overload runs a separate greedy rewrite for each region of the
/// specified op. A region scope can be set in the configuration parameter. By
Expand All @@ -147,57 +164,78 @@ applyPatternsAndFoldGreedily(Region &region,
///
/// Note: This method does not apply patterns to the given operation itself.
inline LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
bool anyRegionChanged = false;
bool failed = false;
for (Region &region : op->getRegions()) {
bool regionChanged;
failed |=
applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
.failed();
failed |= applyPatternsGreedily(region, patterns, config, &regionChanged)
.failed();
anyRegionChanged |= regionChanged;
}
if (changed)
*changed = anyRegionChanged;
return failure(failed);
}
/// Same as `applyPatternsGreedily` above with folding.
/// FIXME: Remove this once transition to above is complieted.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
config.fold = true;
return applyPatternsGreedily(op, patterns, config, changed);
}

/// Rewrite the specified ops by repeatedly applying the highest benefit
/// patterns in a greedy worklist driven manner until a fixpoint is reached.
///
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// Newly created ops and other pre-existing ops that use results of rewritten
/// ops or supply operands to such ops are also processed, unless such ops are
/// excluded via `config.strictMode`. Any other ops remain unmodified (i.e.,
/// regardless of `strictMode`).
///
/// In addition to strictness, a region scope can be specified. Only ops within
/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
/// the scope are simplified. This is similar to `applyPatternsGreedily`,
/// where only ops within the given region/op are simplified by default. If no
/// scope is specified, it is assumed to be the first common enclosing region of
/// the given ops.
///
/// Note that ops in `ops` could be erased as result of folding, becoming dead,
/// or via pattern rewrites. If more far reaching simplification is desired,
/// `applyPatternsAndFoldGreedily` should be used.
/// `applyPatternsGreedily` should be used.
///
/// Returns "success" if the iterative process converged (i.e., fixpoint was
/// reached) and no more patterns can be matched. `changed` is set to "true" if
/// the IR was modified at all. `allOpsErased` is set to "true" if all ops in
/// `ops` were erased.
LogicalResult
applyOpPatternsGreedily(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
/// Same as `applyOpPatternsGreedily` with folding.
/// FIXME: Remove this once transition to above is complieted.
LLVM_DEPRECATED("Use applyOpPatternsGreedily() instead",
"applyOpPatternsGreedily")
inline LogicalResult
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
bool *changed = nullptr, bool *allErased = nullptr) {
config.fold = true;
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
}

} // namespace mlir

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ MlirLogicalResult
mlirApplyPatternsAndFoldGreedily(MlirModule op,
MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig) {
return wrap(
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,6 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ struct ArithToArmSMEConversionPass final
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
arith::populateArithToArmSMEConversionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class ConvertArmNeon2dToIntr
RewritePatternSet patterns(context);
populateConvertArmNeon2dToIntrPatterns(patterns);

if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ struct LowerGpuOpsToNVVMOpsPass
{
RewritePatternSet patterns(m.getContext());
populateGpuRewritePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
if (failed(applyPatternsGreedily(m, std::move(patterns))))
return signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ struct LowerGpuOpsToROCDLOpsPass
RewritePatternSet patterns(ctx);
populateGpuRewritePatterns(patterns);
arith::populateExpandBFloat16Patterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
(void)applyPatternsGreedily(m, std::move(patterns));
}

LLVMTypeConverter converter(ctx, options);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,7 @@ struct ConvertMeshToMPIPass
ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
ctx);

(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns));
(void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ConvertShapeConstraints
RewritePatternSet patterns(context);
populateConvertShapeConstraintsConversionPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
if (failed(applyPatternsGreedily(func, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void ConvertVectorToArmSMEPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateVectorToArmSMEPatterns(patterns, getContext());

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> mlir::createConvertVectorToArmSMEPass() {
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,8 +1326,7 @@ struct ConvertVectorToGPUPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();

IRRewriter rewriter(&getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

// Convert to the LLVM IR dialect.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1730,12 +1730,12 @@ struct ConvertVectorToSCFPass
RewritePatternSet lowerTransferPatterns(&getContext());
mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
lowerTransferPatterns);
(void)applyPatternsAndFoldGreedily(getOperation(),
std::move(lowerTransferPatterns));
(void)applyPatternsGreedily(getOperation(),
std::move(lowerTransferPatterns));

RewritePatternSet patterns(&getContext());
populateVectorToSCFConversionPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,7 @@ struct ConvertVectorToXeGPUPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToXeGPUConversionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
static_cast<RewriterBase::Listener *>(rewriter.getListener());
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,5 @@ void AffineDataCopyGeneration::runOnOperation() {
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
}
Loading

0 comments on commit 09dfc57

Please sign in to comment.