Skip to content

Commit 8a2015b

Browse files
[mlir][IR] Add rewriter API for moving operations
The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`. After an operation was moved, the `notifyOperationInserted` callback is triggered. This may cause listeners such as the greedy pattern rewrite driver to put the op back on the worklist.
1 parent 55cb52b commit 8a2015b

File tree

13 files changed

+85
-19
lines changed

13 files changed

+85
-19
lines changed

mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/include/mlir/IR/Builders.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,18 @@ class OpBuilder : public Builder {
285285

286286
virtual ~Listener() = default;
287287

288-
/// Notification handler for when an operation is inserted into the builder.
289-
/// `op` is the operation that was inserted.
288+
/// Notify the listener that the specified operation was inserted.
289+
///
290+
/// Note: Creating an (unlinked) op does not trigger this notification.
291+
/// Only when the op is inserted, this notification is triggered. This
292+
/// notification is also triggered when moving an operation to a different
293+
/// location.
294+
// TODO: If needed, the previous location of the operation could be passed
295+
// as a parameter. This would also allow listeners to distinguish between
296+
// "newly created op was inserted" and "existing op was moved".
290297
virtual void notifyOperationInserted(Operation *op) {}
291298

292-
/// Notification handler for when a block is created using the builder.
293-
/// `block` is the block that was created.
299+
/// Notify the listener that the specified block was inserted.
294300
virtual void notifyBlockCreated(Block *block) {}
295301

296302
protected:

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
428428

429429
/// Notify the listener that the specified operation is about to be erased.
430430
/// At this point, the operation has zero uses.
431+
///
432+
/// Note: This notification is not triggered when unlinking an operation.
431433
virtual void notifyOperationRemoved(Operation *op) {}
432434

433435
/// Notify the listener that the pattern failed to match the given
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
591593
/// block into a new block, and return it.
592594
virtual Block *splitBlock(Block *block, Block::iterator before);
593595

596+
/// Unlink this operation from its current block and insert it right before
597+
/// `existingOp` which may be in the same or another block in the same
598+
/// function.
599+
void moveOpBefore(Operation *op, Operation *existingOp);
600+
601+
/// Unlink this operation from its current block and insert it right before
602+
/// `iterator` in the specified block.
603+
virtual void moveOpBefore(Operation *op, Block *block,
604+
Block::iterator iterator);
605+
606+
/// Unlink this operation from its current block and insert it right after
607+
/// `existingOp` which may be in the same or another block in the same
608+
/// function.
609+
void moveOpAfter(Operation *op, Operation *existingOp);
610+
611+
/// Unlink this operation from its current block and insert it right after
612+
/// `iterator` in the specified block.
613+
virtual void moveOpAfter(Operation *op, Block *block,
614+
Block::iterator iterator);
615+
594616
/// This method is used to notify the rewriter that an in-place operation
595617
/// modification is about to happen. A call to this function *must* be
596618
/// followed by a call to either `finalizeOpModification` or

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
761761
detail::ConversionPatternRewriterImpl &getImpl();
762762

763763
private:
764+
// Hide unsupported pattern rewriter API.
764765
using OpBuilder::getListener;
765766
using OpBuilder::setListener;
766767

768+
void moveOpBefore(Operation *op, Block *block,
769+
Block::iterator iterator) override;
770+
void moveOpAfter(Operation *op, Block *block,
771+
Block::iterator iterator) override;
772+
767773
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
768774
};
769775

mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
8383

8484
// Inline for-loop body operations into 'after' region.
8585
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
86-
arg.moveBefore(afterBlock, afterBlock->end());
86+
rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
8787

8888
// Add incremented IV to yield operations
8989
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
983983
for (Operation *user : srcBuffer->getUsers()) {
984984
if (hasEffect<MemoryEffects::Free>(user)) {
985985
if (user->getBlock() == parallelCombiningParent->getBlock())
986-
user->moveBefore(user->getBlock()->getTerminator());
986+
rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
987987
break;
988988
}
989989
}

mlir/lib/IR/PatternMatch.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,25 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
366366
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
367367
cloneRegionBefore(region, *before->getParent(), before->getIterator());
368368
}
369+
370+
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
371+
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
372+
}
373+
374+
void RewriterBase::moveOpBefore(Operation *op, Block *block,
375+
Block::iterator iterator) {
376+
op->moveBefore(block, iterator);
377+
if (listener)
378+
listener->notifyOperationInserted(op);
379+
}
380+
381+
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
382+
moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
383+
}
384+
385+
void RewriterBase::moveOpAfter(Operation *op, Block *block,
386+
Block::iterator iterator) {
387+
op->moveAfter(block, iterator);
388+
if (listener)
389+
listener->notifyOperationInserted(op);
390+
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
16511651
return impl->notifyMatchFailure(loc, reasonCallback);
16521652
}
16531653

1654+
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
1655+
Block::iterator iterator) {
1656+
llvm_unreachable(
1657+
"moving single ops is not supported in a dialect conversion");
1658+
}
1659+
1660+
void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
1661+
Block::iterator iterator) {
1662+
llvm_unreachable(
1663+
"moving single ops is not supported in a dialect conversion");
1664+
}
1665+
16541666
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
16551667
return *impl;
16561668
}

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
365365
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
366366
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
367367
OpResult newLoopResult = loopLike.getLoopResults()->back();
368-
extractionOp->moveBefore(loopLike);
369-
insertionOp->moveAfter(loopLike);
368+
rewriter.moveOpBefore(extractionOp, loopLike);
369+
rewriter.moveOpAfter(insertionOp, loopLike);
370370
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
371371
insertionOp.getDestinationOperand().get());
372372
extractionOp.getSourceOperand().set(

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
159159
auto ifOp =
160160
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
161161
// True branch.
162-
op->moveBefore(&ifOp.getThenRegion().front(),
163-
ifOp.getThenRegion().front().begin());
162+
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
163+
ifOp.getThenRegion().front().begin());
164164
rewriter.setInsertionPointAfter(op);
165165
if (op->getNumResults() > 0)
166166
rewriter.create<scf::YieldOp>(loc, op->getResults());

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
193193
return failure();
194194
if (!toBeHoisted->hasAttr("eligible"))
195195
return failure();
196-
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
197-
// op is modified.
198-
rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
196+
rewriter.moveOpBefore(toBeHoisted, op);
199197
return success();
200198
}
201199
};

0 commit comments

Comments
 (0)