Skip to content

Commit 5eae529

Browse files
[mlir][IR] Add rewriter API for moving operations
The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`. After an operation was moved, the `notifyOperationInserted` callback is triggered. This may cause listeners such as the greedy pattern rewrite driver to put the op back on the worklist.
1 parent 55cb52b commit 5eae529

File tree

21 files changed

+135
-43
lines changed

21 files changed

+135
-43
lines changed

mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/include/mlir/IR/Builders.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ class Builder {
205205
/// automatically inserted at an insertion point. The builder is copyable.
206206
class OpBuilder : public Builder {
207207
public:
208+
class InsertPoint;
208209
struct Listener;
209210

210211
/// Create a builder with the given context.
@@ -285,12 +286,17 @@ class OpBuilder : public Builder {
285286

286287
virtual ~Listener() = default;
287288

288-
/// Notification handler for when an operation is inserted into the builder.
289-
/// `op` is the operation that was inserted.
290-
virtual void notifyOperationInserted(Operation *op) {}
291-
292-
/// Notification handler for when a block is created using the builder.
293-
/// `block` is the block that was created.
289+
/// Notify the listener that the specified operation was inserted.
290+
///
291+
/// * If the operation was moved, then `previous` is the previous location
292+
/// of the op.
293+
/// * If the operation was unlinked before it was inserted, then `previous`
294+
/// is empty.
295+
///
296+
/// Note: Creating an (unlinked) op does not trigger this notification.
297+
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
298+
299+
/// Notify the listener that the specified block was inserted.
294300
virtual void notifyBlockCreated(Block *block) {}
295301

296302
protected:
@@ -517,7 +523,7 @@ class OpBuilder : public Builder {
517523
if (succeeded(tryFold(op, results)))
518524
op->erase();
519525
else if (listener)
520-
listener->notifyOperationInserted(op);
526+
listener->notifyOperationInserted(op, /*previous=*/{});
521527
}
522528

523529
/// Overload to create or fold a single result operation.

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
428428

429429
/// Notify the listener that the specified operation is about to be erased.
430430
/// At this point, the operation has zero uses.
431+
///
432+
/// Note: This notification is not triggered when unlinking an operation.
431433
virtual void notifyOperationRemoved(Operation *op) {}
432434

433435
/// Notify the listener that the pattern failed to match the given
@@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder {
450452
struct ForwardingListener : public RewriterBase::Listener {
451453
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
452454

453-
void notifyOperationInserted(Operation *op) override {
454-
listener->notifyOperationInserted(op);
455+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
456+
listener->notifyOperationInserted(op, previous);
455457
}
456458
void notifyBlockCreated(Block *block) override {
457459
listener->notifyBlockCreated(block);
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
591593
/// block into a new block, and return it.
592594
virtual Block *splitBlock(Block *block, Block::iterator before);
593595

596+
/// Unlink this operation from its current block and insert it right before
597+
/// `existingOp` which may be in the same or another block in the same
598+
/// function.
599+
void moveOpBefore(Operation *op, Operation *existingOp);
600+
601+
/// Unlink this operation from its current block and insert it right before
602+
/// `iterator` in the specified block.
603+
virtual void moveOpBefore(Operation *op, Block *block,
604+
Block::iterator iterator);
605+
606+
/// Unlink this operation from its current block and insert it right after
607+
/// `existingOp` which may be in the same or another block in the same
608+
/// function.
609+
void moveOpAfter(Operation *op, Operation *existingOp);
610+
611+
/// Unlink this operation from its current block and insert it right after
612+
/// `iterator` in the specified block.
613+
virtual void moveOpAfter(Operation *op, Block *block,
614+
Block::iterator iterator);
615+
594616
/// This method is used to notify the rewriter that an in-place operation
595617
/// modification is about to happen. A call to this function *must* be
596618
/// followed by a call to either `finalizeOpModification` or

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
737737
using PatternRewriter::cloneRegionBefore;
738738

739739
/// PatternRewriter hook for inserting a new operation.
740-
void notifyOperationInserted(Operation *op) override;
740+
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
741741

742742
/// PatternRewriter hook for updating the given operation in-place.
743743
/// Note: These methods only track updates to the given operation itself,
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
761761
detail::ConversionPatternRewriterImpl &getImpl();
762762

763763
private:
764+
// Hide unsupported pattern rewriter API.
764765
using OpBuilder::getListener;
765766
using OpBuilder::setListener;
766767

768+
void moveOpBefore(Operation *op, Block *block,
769+
Block::iterator iterator) override;
770+
void moveOpAfter(Operation *op, Block *block,
771+
Block::iterator iterator) override;
772+
767773
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
768774
};
769775

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12061206
if (failed(applyOp->fold(constOperands, foldResults)) ||
12071207
foldResults.empty()) {
12081208
if (OpBuilder::Listener *listener = b.getListener())
1209-
listener->notifyOperationInserted(applyOp);
1209+
listener->notifyOperationInserted(applyOp, /*previous=*/{});
12101210
return applyOp.getResult();
12111211
}
12121212

@@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
12741274
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
12751275
foldResults.empty()) {
12761276
if (OpBuilder::Listener *listener = b.getListener())
1277-
listener->notifyOperationInserted(minMaxOp);
1277+
listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
12781278
return minMaxOp.getResult();
12791279
}
12801280

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
273273
// Insert function into the module symbol table and assign it unique name.
274274
SymbolTable symbolTable(module);
275275
symbolTable.insert(func);
276-
rewriter.getListener()->notifyOperationInserted(func);
276+
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
277277

278278
// Create function entry block.
279279
Block *block =
@@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
489489
// Insert function into the module symbol table and assign it unique name.
490490
SymbolTable symbolTable(module);
491491
symbolTable.insert(func);
492-
rewriter.getListener()->notifyOperationInserted(func);
492+
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
493493

494494
// Create function entry block.
495495
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,11 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
371371
toMemrefOps.erase(op);
372372
}
373373

374-
void notifyOperationInserted(Operation *op) override {
374+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
375+
// We only care about newly created ops.
376+
if (previous.isSet())
377+
return;
378+
375379
erasedOps.erase(op);
376380

377381
// Gather statistics about allocs.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
214214
}
215215

216216
private:
217-
void notifyOperationInserted(Operation *op) override {
218-
ForwardingListener::notifyOperationInserted(op);
217+
void notifyOperationInserted(Operation *op,
218+
OpBuilder::InsertPoint previous) override {
219+
ForwardingListener::notifyOperationInserted(op, previous);
220+
// We only care about newly created ops.
221+
if (previous.isSet())
222+
return;
219223
auto inserted = newOps.insert(op);
220224
(void)inserted;
221225
assert(inserted.second && "expected newly created op");

mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
8383

8484
// Inline for-loop body operations into 'after' region.
8585
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
86-
arg.moveBefore(afterBlock, afterBlock->end());
86+
rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
8787

8888
// Add incremented IV to yield operations
8989
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
983983
for (Operation *user : srcBuffer->getUsers()) {
984984
if (hasEffect<MemoryEffects::Free>(user)) {
985985
if (user->getBlock() == parallelCombiningParent->getBlock())
986-
user->moveBefore(user->getBlock()->getTerminator());
986+
rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
987987
break;
988988
}
989989
}

mlir/lib/IR/Builders.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
412412
block->getOperations().insert(insertPoint, op);
413413

414414
if (listener)
415-
listener->notifyOperationInserted(op);
415+
listener->notifyOperationInserted(op, /*previous=*/{});
416416
return op;
417417
}
418418

@@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
530530
// about any ops that got inserted inside those regions as part of cloning.
531531
if (listener) {
532532
auto walkFn = [&](Operation *walkedOp) {
533-
listener->notifyOperationInserted(walkedOp);
533+
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
534534
};
535535
for (Region &region : newOp->getRegions())
536536
region.walk<WalkOrder::PreOrder>(walkFn);

mlir/lib/IR/PatternMatch.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
366366
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
367367
cloneRegionBefore(region, *before->getParent(), before->getIterator());
368368
}
369+
370+
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
371+
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
372+
}
373+
374+
void RewriterBase::moveOpBefore(Operation *op, Block *block,
375+
Block::iterator iterator) {
376+
Block *currentBlock = op->getBlock();
377+
Block::iterator currentIterator = op->getIterator();
378+
op->moveBefore(block, iterator);
379+
if (listener)
380+
listener->notifyOperationInserted(
381+
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
382+
}
383+
384+
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
385+
moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
386+
}
387+
388+
void RewriterBase::moveOpAfter(Operation *op, Block *block,
389+
Block::iterator iterator) {
390+
Block *currentBlock = op->getBlock();
391+
Block::iterator currentIterator = op->getIterator();
392+
op->moveAfter(block, iterator);
393+
if (listener)
394+
listener->notifyOperationInserted(
395+
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
396+
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
16021602
Block *cloned = mapping.lookup(&b);
16031603
impl->notifyCreatedBlock(cloned);
16041604
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
1605-
[&](Operation *op) { notifyOperationInserted(op); });
1605+
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
16061606
}
16071607
}
16081608

1609-
void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1609+
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
1610+
InsertPoint previous) {
1611+
assert(!previous.isSet() && "expected newly created op");
16101612
LLVM_DEBUG({
16111613
impl->logger.startLine()
16121614
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
@@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
16511653
return impl->notifyMatchFailure(loc, reasonCallback);
16521654
}
16531655

1656+
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
1657+
Block::iterator iterator) {
1658+
llvm_unreachable(
1659+
"moving single ops is not supported in a dialect conversion");
1660+
}
1661+
1662+
void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
1663+
Block::iterator iterator) {
1664+
llvm_unreachable(
1665+
"moving single ops is not supported in a dialect conversion");
1666+
}
1667+
16541668
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
16551669
return *impl;
16561670
}

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,16 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
133133
}
134134
}
135135

136-
void notifyOperationInserted(Operation *op) override {
137-
RewriterBase::ForwardingListener::notifyOperationInserted(op);
136+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
137+
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
138+
// Invalidate the finger print of the op that owns the block into which the
139+
// op was inserted into.
138140
invalidateFingerPrint(op->getParentOp());
141+
142+
// Also invalidate the finger print of the op that owns the block from which
143+
// the op was moved from. (Only applicable if the op was moved.)
144+
if (previous.isSet())
145+
invalidateFingerPrint(previous.getBlock()->getParentOp());
139146
}
140147

141148
void notifyOperationModified(Operation *op) override {
@@ -331,7 +338,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
331338
/// Notify the driver that the specified operation was inserted. Update the
332339
/// worklist as needed: The operation is enqueued depending on scope and
333340
/// strict mode.
334-
void notifyOperationInserted(Operation *op) override;
341+
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
335342

336343
/// Notify the driver that the specified operation was removed. Update the
337344
/// worklist as needed: The operation and its children are removed from the
@@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
641648
config.listener->notifyBlockRemoved(block);
642649
}
643650

644-
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
651+
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
652+
InsertPoint previous) {
645653
LLVM_DEBUG({
646654
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
647655
<< ")\n";
648656
});
649657
if (config.listener)
650-
config.listener->notifyOperationInserted(op);
658+
config.listener->notifyOperationInserted(op, previous);
651659
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
652660
strictModeFilteredOps.insert(op);
653661
addToWorklist(op);

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
365365
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
366366
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
367367
OpResult newLoopResult = loopLike.getLoopResults()->back();
368-
extractionOp->moveBefore(loopLike);
369-
insertionOp->moveAfter(loopLike);
368+
rewriter.moveOpBefore(extractionOp, loopLike);
369+
rewriter.moveOpAfter(insertionOp, loopLike);
370370
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
371371
insertionOp.getDestinationOperand().get());
372372
extractionOp.getSourceOperand().set(

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
159159
auto ifOp =
160160
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
161161
// True branch.
162-
op->moveBefore(&ifOp.getThenRegion().front(),
163-
ifOp.getThenRegion().front().begin());
162+
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
163+
ifOp.getThenRegion().front().begin());
164164
rewriter.setInsertionPointAfter(op);
165165
if (op->getNumResults() > 0)
166166
rewriter.create<scf::YieldOp>(loc, op->getResults());

0 commit comments

Comments
 (0)