Skip to content

Commit

Permalink
[mlir] Add option to limit number of pattern rewrites in Canonicalize…
Browse files Browse the repository at this point in the history
…rPass

The greedy pattern rewriter consists of two nested loops. `config.maxIterations` (which configurable on the CanonicalizerPass) controls the maximum number of iterations of the outer loop.

```
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
int64_t maxIterations = 10;
```

This change adds `config.maxNumRewrites` which controls the maximum number of pattern rewrites within an iteration. (It effectively control the maximum number of iterations of the inner loop.)

This flag is meant for debugging and useful in cases where one or multiple faulty patterns can be applied indefinitely, resulting in an infinite loop.

Differential Revision: https://reviews.llvm.org/D140525
  • Loading branch information
matthias-springer committed Dec 23, 2022
1 parent 4861a58 commit 391cb54
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 14 deletions.
10 changes: 7 additions & 3 deletions mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ class GreedyRewriteConfig {
bool enableRegionSimplification = true;

/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoIterationLimit`
/// to disable this iteration limit.
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
int64_t maxIterations = 10;

static constexpr int64_t kNoIterationLimit = -1;
/// This specifies the maximum number of rewrites within an iteration. Use
/// `kNoLimit` to disable this limit.
int64_t maxNumRewrites = kNoLimit;

static constexpr int64_t kNoLimit = -1;
};

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ def Canonicalizer : Pass<"canonicalize"> {
"Seed the worklist in general top-down order">,
Option<"enableRegionSimplification", "region-simplify", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
"Perform control flow optimizations to the region tree">,
Option<"maxIterations", "max-iterations", "int64_t",
/*default=*/"10",
"Seed the worklist in general top-down order">
"Max. iterations between applying patterns / simplifying regions">,
Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1",
"Max. number of pattern rewrites within an iteration">
] # RewritePassUtils.options;
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Transforms/Canonicalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
this->topDownProcessingEnabled = config.useTopDownTraversal;
this->enableRegionSimplification = config.enableRegionSimplification;
this->maxIterations = config.maxIterations;
this->maxNumRewrites = config.maxNumRewrites;
this->disabledPatterns = disabledPatterns;
this->enabledPatterns = enabledPatterns;
}
Expand All @@ -55,6 +56,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
config.useTopDownTraversal = topDownProcessingEnabled;
config.enableRegionSimplification = enableRegionSimplification;
config.maxIterations = maxIterations;
config.maxNumRewrites = maxNumRewrites;
(void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
}

Expand Down
18 changes: 11 additions & 7 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
SmallVector<Value, 8> originalOperands, resultValues;

changed = false;
int64_t numRewrites = 0;
while (!worklist.empty()) {
auto *op = popFromWorklist();

Expand Down Expand Up @@ -279,16 +280,20 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
#else
LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
#endif
changed |= succeeded(matchResult);
if (succeeded(matchResult)) {
changed = true;
if (numRewrites++ >= config.maxNumRewrites &&
config.maxNumRewrites != GreedyRewriteConfig::kNoLimit)
break;
}
}

// After applying patterns, make sure that the CFG of each of the regions
// is kept up to date.
if (config.enableRegionSimplification)
changed |= succeeded(simplifyRegions(*this, regions));
} while (changed &&
(iteration++ < config.maxIterations ||
config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
} while (changed && (iteration++ < config.maxIterations ||
config.maxIterations == GreedyRewriteConfig::kNoLimit));

// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return !changed;
Expand Down Expand Up @@ -506,9 +511,8 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
changed |= succeeded(matcher.matchAndRewrite(op, *this));
if ((erased = opErasedViaPatternRewrites))
return success();
} while (changed &&
(++iterations < maxIterations ||
maxIterations == GreedyRewriteConfig::kNoIterationLimit));
} while (changed && (++iterations < maxIterations ||
maxIterations == GreedyRewriteConfig::kNoLimit));

// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return failure(changed);
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Pass/run-reproducer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ func.func @bar() {
external_resources: {
mlir_reproducer: {
verify_each: true,
// CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 region-simplify=false top-down=false}))
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 region-simplify=false top-down=false}))",
// CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
disable_threading: true
}
}
Expand Down

0 comments on commit 391cb54

Please sign in to comment.