From 391cb541223bb0d41620eb5e25c107563dc3e12c Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 23 Dec 2022 13:01:00 +0100 Subject: [PATCH] [mlir] Add option to limit number of pattern rewrites in CanonicalizerPass The greedy pattern rewriter consists of two nested loops. `config.maxIterations` (which configurable on the CanonicalizerPass) controls the maximum number of iterations of the outer loop. ``` /// This specifies the maximum number of times the rewriter will iterate /// between applying patterns and simplifying regions. Use `kNoLimit` to /// disable this iteration limit. int64_t maxIterations = 10; ``` This change adds `config.maxNumRewrites` which controls the maximum number of pattern rewrites within an iteration. (It effectively control the maximum number of iterations of the inner loop.) This flag is meant for debugging and useful in cases where one or multiple faulty patterns can be applied indefinitely, resulting in an infinite loop. Differential Revision: https://reviews.llvm.org/D140525 --- .../Transforms/GreedyPatternRewriteDriver.h | 10 +++++++--- mlir/include/mlir/Transforms/Passes.td | 6 ++++-- mlir/lib/Transforms/Canonicalizer.cpp | 2 ++ .../Utils/GreedyPatternRewriteDriver.cpp | 18 +++++++++++------- mlir/test/Pass/run-reproducer.mlir | 4 ++-- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index d9d272110b310..5478587dcc43d 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -33,11 +33,15 @@ class GreedyRewriteConfig { bool enableRegionSimplification = true; /// This specifies the maximum number of times the rewriter will iterate - /// between applying patterns and simplifying regions. Use `kNoIterationLimit` - /// to disable this iteration limit. + /// between applying patterns and simplifying regions. Use `kNoLimit` to + /// disable this iteration limit. int64_t maxIterations = 10; - static constexpr int64_t kNoIterationLimit = -1; + /// This specifies the maximum number of rewrites within an iteration. Use + /// `kNoLimit` to disable this limit. + int64_t maxNumRewrites = kNoLimit; + + static constexpr int64_t kNoLimit = -1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index d45f5f08b3008..e7d122323ae33 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -30,10 +30,12 @@ def Canonicalizer : Pass<"canonicalize"> { "Seed the worklist in general top-down order">, Option<"enableRegionSimplification", "region-simplify", "bool", /*default=*/"true", - "Seed the worklist in general top-down order">, + "Perform control flow optimizations to the region tree">, Option<"maxIterations", "max-iterations", "int64_t", /*default=*/"10", - "Seed the worklist in general top-down order"> + "Max. iterations between applying patterns / simplifying regions">, + Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1", + "Max. number of pattern rewrites within an iteration"> ] # RewritePassUtils.options; } diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index a4215629a964e..dc3bf97b32388 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase { this->topDownProcessingEnabled = config.useTopDownTraversal; this->enableRegionSimplification = config.enableRegionSimplification; this->maxIterations = config.maxIterations; + this->maxNumRewrites = config.maxNumRewrites; this->disabledPatterns = disabledPatterns; this->enabledPatterns = enabledPatterns; } @@ -55,6 +56,7 @@ struct Canonicalizer : public impl::CanonicalizerBase { config.useTopDownTraversal = topDownProcessingEnabled; config.enableRegionSimplification = enableRegionSimplification; config.maxIterations = maxIterations; + config.maxNumRewrites = maxNumRewrites; (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 935ca2eb93740..0d6fdaf3039cf 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -183,6 +183,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { SmallVector originalOperands, resultValues; changed = false; + int64_t numRewrites = 0; while (!worklist.empty()) { auto *op = popFromWorklist(); @@ -279,16 +280,20 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { #else LogicalResult matchResult = matcher.matchAndRewrite(op, *this); #endif - changed |= succeeded(matchResult); + if (succeeded(matchResult)) { + changed = true; + if (numRewrites++ >= config.maxNumRewrites && + config.maxNumRewrites != GreedyRewriteConfig::kNoLimit) + break; + } } // After applying patterns, make sure that the CFG of each of the regions // is kept up to date. if (config.enableRegionSimplification) changed |= succeeded(simplifyRegions(*this, regions)); - } while (changed && - (iteration++ < config.maxIterations || - config.maxIterations == GreedyRewriteConfig::kNoIterationLimit)); + } while (changed && (iteration++ < config.maxIterations || + config.maxIterations == GreedyRewriteConfig::kNoLimit)); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return !changed; @@ -506,9 +511,8 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, changed |= succeeded(matcher.matchAndRewrite(op, *this)); if ((erased = opErasedViaPatternRewrites)) return success(); - } while (changed && - (++iterations < maxIterations || - maxIterations == GreedyRewriteConfig::kNoIterationLimit)); + } while (changed && (++iterations < maxIterations || + maxIterations == GreedyRewriteConfig::kNoLimit)); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return failure(changed); diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir index 496471d032a52..3a958f8a92509 100644 --- a/mlir/test/Pass/run-reproducer.mlir +++ b/mlir/test/Pass/run-reproducer.mlir @@ -14,8 +14,8 @@ func.func @bar() { external_resources: { mlir_reproducer: { verify_each: true, - // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 region-simplify=false top-down=false})) - pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 region-simplify=false top-down=false}))", + // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false})) + pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))", disable_threading: true } }