Skip to content

[mlir] Enable decoupling two kinds of greedy behavior. #104649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class InlineElementalsPass
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR elemental inlining");
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,8 @@ class LowerHLFIRIntrinsics
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
module, std::move(patterns), config))) {
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
mlir::emitError(mlir::UnknownLoc::get(context),
"failure in HLFIR intrinsic lowering");
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,7 @@ class OptimizedBufferizationPass
// patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
// patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR optimized bufferization");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class SimplifyHLFIRIntrinsics
patterns.insert<SumAsElementalConversion>(context);
patterns.insert<CShiftAsElementalConversion>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR intrinsic simplification");
Expand Down
3 changes: 1 addition & 2 deletions flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ struct AlgebraicSimplification
void AlgebraicSimplification::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateMathAlgebraicSimplificationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}

std::unique_ptr<mlir::Pass> fir::createAlgebraicSimplificationPass() {
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class AssumedRankOpConversion
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
(void)applyPatternsGreedily(mod, std::move(patterns), config);
}
};
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ class ConstantArgumentGlobalisationOpt
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;

patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
mod, std::move(patterns), config))) {
if (mlir::failed(
mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
mlir::emitError(mod.getLoc(),
"error in constant globalisation optimization\n");
signalPassFailure();
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,8 +793,8 @@ void StackArraysPass::runOnOperation() {
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
std::move(patterns), config))) {
if (mlir::failed(mlir::applyOpPatternsGreedily(
opsToConvert, std::move(patterns), config))) {
mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
signalPassFailure();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ which point the driver finishes.

This driver comes in two fashions:

* `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to
* `applyPatternsGreedily` ("region-based driver") applies patterns to
all ops in a given region or a given container op (but not the container op
itself). I.e., the worklist is initialized with all containing ops.
* `applyOpPatternsAndFold` ("op-based driver") applies patterns to the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class StandaloneSwitchBarFoo
RewritePatternSet patterns(&getContext());
patterns.add<StandaloneSwitchBarFooRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
if (failed(applyPatternsGreedily(getOperation(), patternSet)))
signalPassFailure();
}
};
Expand Down
72 changes: 55 additions & 17 deletions mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ class GreedyRewriteConfig {

/// An optional listener that should be notified about IR modifications.
RewriterBase::Listener *listener = nullptr;

/// Whether this should fold while greedily rewriting.
bool fold = true;

/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
bool cseConstants = true;
};

//===----------------------------------------------------------------------===//
Expand All @@ -104,8 +111,8 @@ class GreedyRewriteConfig {
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// A region scope can be set in the configuration parameter. By default, the
/// scope is set to the specified region. Only in-scope ops are added to the
Expand All @@ -117,10 +124,20 @@ class GreedyRewriteConfig {
///
/// Note: This method does not apply patterns to the region's parent operation.
LogicalResult
applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Same as `applyPatternsAndGreedily` above with folding.
/// FIXME: Remove this once transition to above is complieted.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
bool *changed = nullptr) {
config.fold = true;
return applyPatternsGreedily(region, patterns, config, changed);
}

/// Rewrite ops nested under the given operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
Expand All @@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region &region,
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// This overload runs a separate greedy rewrite for each region of the
/// specified op. A region scope can be set in the configuration parameter. By
Expand All @@ -147,57 +164,78 @@ applyPatternsAndFoldGreedily(Region &region,
///
/// Note: This method does not apply patterns to the given operation itself.
inline LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
bool anyRegionChanged = false;
bool failed = false;
for (Region &region : op->getRegions()) {
bool regionChanged;
failed |=
applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
.failed();
failed |= applyPatternsGreedily(region, patterns, config, &regionChanged)
.failed();
anyRegionChanged |= regionChanged;
}
if (changed)
*changed = anyRegionChanged;
return failure(failed);
}
/// Same as `applyPatternsGreedily` above with folding.
/// FIXME: Remove this once transition to above is complieted.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
config.fold = true;
return applyPatternsGreedily(op, patterns, config, changed);
}

/// Rewrite the specified ops by repeatedly applying the highest benefit
/// patterns in a greedy worklist driven manner until a fixpoint is reached.
///
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// Newly created ops and other pre-existing ops that use results of rewritten
/// ops or supply operands to such ops are also processed, unless such ops are
/// excluded via `config.strictMode`. Any other ops remain unmodified (i.e.,
/// regardless of `strictMode`).
///
/// In addition to strictness, a region scope can be specified. Only ops within
/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
/// the scope are simplified. This is similar to `applyPatternsGreedily`,
/// where only ops within the given region/op are simplified by default. If no
/// scope is specified, it is assumed to be the first common enclosing region of
/// the given ops.
///
/// Note that ops in `ops` could be erased as result of folding, becoming dead,
/// or via pattern rewrites. If more far reaching simplification is desired,
/// `applyPatternsAndFoldGreedily` should be used.
/// `applyPatternsGreedily` should be used.
///
/// Returns "success" if the iterative process converged (i.e., fixpoint was
/// reached) and no more patterns can be matched. `changed` is set to "true" if
/// the IR was modified at all. `allOpsErased` is set to "true" if all ops in
/// `ops` were erased.
LogicalResult
applyOpPatternsGreedily(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
/// Same as `applyOpPatternsGreedily` with folding.
/// FIXME: Remove this once transition to above is complieted.
LLVM_DEPRECATED("Use applyOpPatternsGreedily() instead",
"applyOpPatternsGreedily")
inline LogicalResult
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
bool *changed = nullptr, bool *allErased = nullptr) {
config.fold = true;
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
}

} // namespace mlir

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ MlirLogicalResult
mlirApplyPatternsAndFoldGreedily(MlirModule op,
MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig) {
return wrap(
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,6 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ struct ArithToArmSMEConversionPass final
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
arith::populateArithToArmSMEConversionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class ConvertArmNeon2dToIntr
RewritePatternSet patterns(context);
populateConvertArmNeon2dToIntrPatterns(patterns);

if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ struct LowerGpuOpsToNVVMOpsPass
{
RewritePatternSet patterns(m.getContext());
populateGpuRewritePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
if (failed(applyPatternsGreedily(m, std::move(patterns))))
return signalPassFailure();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ struct LowerGpuOpsToROCDLOpsPass
RewritePatternSet patterns(ctx);
populateGpuRewritePatterns(patterns);
arith::populateExpandBFloat16Patterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
(void)applyPatternsGreedily(m, std::move(patterns));
}

LLVMTypeConverter converter(ctx, options);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,7 @@ struct ConvertMeshToMPIPass
ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
ctx);

(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns));
(void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ConvertShapeConstraints
RewritePatternSet patterns(context);
populateConvertShapeConstraintsConversionPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
if (failed(applyPatternsGreedily(func, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void ConvertVectorToArmSMEPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateVectorToArmSMEPatterns(patterns, getContext());

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> mlir::createConvertVectorToArmSMEPass() {
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,8 +1326,7 @@ struct ConvertVectorToGPUPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();

IRRewriter rewriter(&getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

// Convert to the LLVM IR dialect.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1730,12 +1730,12 @@ struct ConvertVectorToSCFPass
RewritePatternSet lowerTransferPatterns(&getContext());
mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
lowerTransferPatterns);
(void)applyPatternsAndFoldGreedily(getOperation(),
std::move(lowerTransferPatterns));
(void)applyPatternsGreedily(getOperation(),
std::move(lowerTransferPatterns));

RewritePatternSet patterns(&getContext());
populateVectorToSCFConversionPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,7 @@ struct ConvertVectorToXeGPUPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToXeGPUConversionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
static_cast<RewriterBase::Listener *>(rewriter.getListener());
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,5 @@ void AffineDataCopyGeneration::runOnOperation() {
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
}
Loading
Loading