Skip to content

Commit 5cc0f76

Browse files
[mlir][IR] Add rewriter API for moving operations (#78988)
The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`. After an operation was moved, the `notifyOperationInserted` callback is triggered. This allows listeners such as the greedy pattern rewrite driver to react to IR changes. This commit narrows the discrepancy between the kind of IR modification that can be performed and the kind of IR modifications that can be listened to.
1 parent 45fec0c commit 5cc0f76

File tree

20 files changed

+138
-41
lines changed

20 files changed

+138
-41
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
490490
LLVM_DUMP_METHOD void dumpFunc();
491491

492492
/// FirOpBuilder hook for creating new operation.
493-
void notifyOperationInserted(mlir::Operation *op) override {
493+
void notifyOperationInserted(mlir::Operation *op,
494+
mlir::OpBuilder::InsertPoint previous) override {
495+
// We only care about newly created operations.
496+
if (previous.isSet())
497+
return;
494498
setCommonAttributes(op);
495499
}
496500

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
730730
HLFIRListener(fir::FirOpBuilder &builder,
731731
mlir::ConversionPatternRewriter &rewriter)
732732
: builder{builder}, rewriter{rewriter} {}
733-
void notifyOperationInserted(mlir::Operation *op) override {
734-
builder.notifyOperationInserted(op);
735-
rewriter.notifyOperationInserted(op);
733+
void notifyOperationInserted(mlir::Operation *op,
734+
mlir::OpBuilder::InsertPoint previous) override {
735+
builder.notifyOperationInserted(op, previous);
736+
rewriter.notifyOperationInserted(op, previous);
736737
}
737738
virtual void notifyBlockCreated(mlir::Block *block) override {
738739
builder.notifyBlockCreated(block);

mlir/include/mlir/IR/Builders.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ class Builder {
205205
/// automatically inserted at an insertion point. The builder is copyable.
206206
class OpBuilder : public Builder {
207207
public:
208+
class InsertPoint;
208209
struct Listener;
209210

210211
/// Create a builder with the given context.
@@ -285,12 +286,17 @@ class OpBuilder : public Builder {
285286

286287
virtual ~Listener() = default;
287288

288-
/// Notification handler for when an operation is inserted into the builder.
289-
/// `op` is the operation that was inserted.
290-
virtual void notifyOperationInserted(Operation *op) {}
291-
292-
/// Notification handler for when a block is created using the builder.
293-
/// `block` is the block that was created.
289+
/// Notify the listener that the specified operation was inserted.
290+
///
291+
/// * If the operation was moved, then `previous` is the previous location
292+
/// of the op.
293+
/// * If the operation was unlinked before it was inserted, then `previous`
294+
/// is empty.
295+
///
296+
/// Note: Creating an (unlinked) op does not trigger this notification.
297+
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
298+
299+
/// Notify the listener that the specified block was inserted.
294300
virtual void notifyBlockCreated(Block *block) {}
295301

296302
protected:
@@ -517,7 +523,7 @@ class OpBuilder : public Builder {
517523
if (succeeded(tryFold(op, results)))
518524
op->erase();
519525
else if (listener)
520-
listener->notifyOperationInserted(op);
526+
listener->notifyOperationInserted(op, /*previous=*/{});
521527
}
522528

523529
/// Overload to create or fold a single result operation.

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
428428

429429
/// Notify the listener that the specified operation is about to be erased.
430430
/// At this point, the operation has zero uses.
431+
///
432+
/// Note: This notification is not triggered when unlinking an operation.
431433
virtual void notifyOperationRemoved(Operation *op) {}
432434

433435
/// Notify the listener that the pattern failed to match the given
@@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder {
450452
struct ForwardingListener : public RewriterBase::Listener {
451453
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
452454

453-
void notifyOperationInserted(Operation *op) override {
454-
listener->notifyOperationInserted(op);
455+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
456+
listener->notifyOperationInserted(op, previous);
455457
}
456458
void notifyBlockCreated(Block *block) override {
457459
listener->notifyBlockCreated(block);
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
591593
/// block into a new block, and return it.
592594
virtual Block *splitBlock(Block *block, Block::iterator before);
593595

596+
/// Unlink this operation from its current block and insert it right before
597+
/// `existingOp` which may be in the same or another block in the same
598+
/// function.
599+
void moveOpBefore(Operation *op, Operation *existingOp);
600+
601+
/// Unlink this operation from its current block and insert it right before
602+
/// `iterator` in the specified block.
603+
virtual void moveOpBefore(Operation *op, Block *block,
604+
Block::iterator iterator);
605+
606+
/// Unlink this operation from its current block and insert it right after
607+
/// `existingOp` which may be in the same or another block in the same
608+
/// function.
609+
void moveOpAfter(Operation *op, Operation *existingOp);
610+
611+
/// Unlink this operation from its current block and insert it right after
612+
/// `iterator` in the specified block.
613+
virtual void moveOpAfter(Operation *op, Block *block,
614+
Block::iterator iterator);
615+
594616
/// This method is used to notify the rewriter that an in-place operation
595617
/// modification is about to happen. A call to this function *must* be
596618
/// followed by a call to either `finalizeOpModification` or

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
737737
using PatternRewriter::cloneRegionBefore;
738738

739739
/// PatternRewriter hook for inserting a new operation.
740-
void notifyOperationInserted(Operation *op) override;
740+
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
741741

742742
/// PatternRewriter hook for updating the given operation in-place.
743743
/// Note: These methods only track updates to the given operation itself,
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
761761
detail::ConversionPatternRewriterImpl &getImpl();
762762

763763
private:
764+
// Hide unsupported pattern rewriter API.
764765
using OpBuilder::getListener;
765766
using OpBuilder::setListener;
766767

768+
void moveOpBefore(Operation *op, Block *block,
769+
Block::iterator iterator) override;
770+
void moveOpAfter(Operation *op, Block *block,
771+
Block::iterator iterator) override;
772+
767773
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
768774
};
769775

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12061206
if (failed(applyOp->fold(constOperands, foldResults)) ||
12071207
foldResults.empty()) {
12081208
if (OpBuilder::Listener *listener = b.getListener())
1209-
listener->notifyOperationInserted(applyOp);
1209+
listener->notifyOperationInserted(applyOp, /*previous=*/{});
12101210
return applyOp.getResult();
12111211
}
12121212

@@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
12741274
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
12751275
foldResults.empty()) {
12761276
if (OpBuilder::Listener *listener = b.getListener())
1277-
listener->notifyOperationInserted(minMaxOp);
1277+
listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
12781278
return minMaxOp.getResult();
12791279
}
12801280

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
273273
// Insert function into the module symbol table and assign it unique name.
274274
SymbolTable symbolTable(module);
275275
symbolTable.insert(func);
276-
rewriter.getListener()->notifyOperationInserted(func);
276+
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
277277

278278
// Create function entry block.
279279
Block *block =
@@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
489489
// Insert function into the module symbol table and assign it unique name.
490490
SymbolTable symbolTable(module);
491491
symbolTable.insert(func);
492-
rewriter.getListener()->notifyOperationInserted(func);
492+
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
493493

494494
// Create function entry block.
495495
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,11 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
371371
toMemrefOps.erase(op);
372372
}
373373

374-
void notifyOperationInserted(Operation *op) override {
374+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
375+
// We only care about newly created ops.
376+
if (previous.isSet())
377+
return;
378+
375379
erasedOps.erase(op);
376380

377381
// Gather statistics about allocs.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
214214
}
215215

216216
private:
217-
void notifyOperationInserted(Operation *op) override {
218-
ForwardingListener::notifyOperationInserted(op);
217+
void notifyOperationInserted(Operation *op,
218+
OpBuilder::InsertPoint previous) override {
219+
ForwardingListener::notifyOperationInserted(op, previous);
220+
// We only care about newly created ops.
221+
if (previous.isSet())
222+
return;
219223
auto inserted = newOps.insert(op);
220224
(void)inserted;
221225
assert(inserted.second && "expected newly created op");

mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
8383

8484
// Inline for-loop body operations into 'after' region.
8585
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
86-
arg.moveBefore(afterBlock, afterBlock->end());
86+
rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
8787

8888
// Add incremented IV to yield operations
8989
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
983983
for (Operation *user : srcBuffer->getUsers()) {
984984
if (hasEffect<MemoryEffects::Free>(user)) {
985985
if (user->getBlock() == parallelCombiningParent->getBlock())
986-
user->moveBefore(user->getBlock()->getTerminator());
986+
rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
987987
break;
988988
}
989989
}

mlir/lib/IR/Builders.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
412412
block->getOperations().insert(insertPoint, op);
413413

414414
if (listener)
415-
listener->notifyOperationInserted(op);
415+
listener->notifyOperationInserted(op, /*previous=*/{});
416416
return op;
417417
}
418418

@@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
530530
// about any ops that got inserted inside those regions as part of cloning.
531531
if (listener) {
532532
auto walkFn = [&](Operation *walkedOp) {
533-
listener->notifyOperationInserted(walkedOp);
533+
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
534534
};
535535
for (Region &region : newOp->getRegions())
536536
region.walk<WalkOrder::PreOrder>(walkFn);

mlir/lib/IR/PatternMatch.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
366366
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
367367
cloneRegionBefore(region, *before->getParent(), before->getIterator());
368368
}
369+
370+
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
371+
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
372+
}
373+
374+
void RewriterBase::moveOpBefore(Operation *op, Block *block,
375+
Block::iterator iterator) {
376+
Block *currentBlock = op->getBlock();
377+
Block::iterator currentIterator = op->getIterator();
378+
op->moveBefore(block, iterator);
379+
if (listener)
380+
listener->notifyOperationInserted(
381+
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
382+
}
383+
384+
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
385+
moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
386+
}
387+
388+
void RewriterBase::moveOpAfter(Operation *op, Block *block,
389+
Block::iterator iterator) {
390+
Block *currentBlock = op->getBlock();
391+
Block::iterator currentIterator = op->getIterator();
392+
op->moveAfter(block, iterator);
393+
if (listener)
394+
listener->notifyOperationInserted(
395+
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
396+
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
16021602
Block *cloned = mapping.lookup(&b);
16031603
impl->notifyCreatedBlock(cloned);
16041604
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
1605-
[&](Operation *op) { notifyOperationInserted(op); });
1605+
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
16061606
}
16071607
}
16081608

1609-
void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
1609+
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
1610+
InsertPoint previous) {
1611+
assert(!previous.isSet() && "expected newly created op");
16101612
LLVM_DEBUG({
16111613
impl->logger.startLine()
16121614
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
@@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
16511653
return impl->notifyMatchFailure(loc, reasonCallback);
16521654
}
16531655

1656+
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
1657+
Block::iterator iterator) {
1658+
llvm_unreachable(
1659+
"moving single ops is not supported in a dialect conversion");
1660+
}
1661+
1662+
void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
1663+
Block::iterator iterator) {
1664+
llvm_unreachable(
1665+
"moving single ops is not supported in a dialect conversion");
1666+
}
1667+
16541668
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
16551669
return *impl;
16561670
}

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,16 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
133133
}
134134
}
135135

136-
void notifyOperationInserted(Operation *op) override {
137-
RewriterBase::ForwardingListener::notifyOperationInserted(op);
136+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
137+
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
138+
// Invalidate the finger print of the op that owns the block into which the
139+
// op was inserted into.
138140
invalidateFingerPrint(op->getParentOp());
141+
142+
// Also invalidate the finger print of the op that owns the block from which
143+
// the op was moved from. (Only applicable if the op was moved.)
144+
if (previous.isSet())
145+
invalidateFingerPrint(previous.getBlock()->getParentOp());
139146
}
140147

141148
void notifyOperationModified(Operation *op) override {
@@ -331,7 +338,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
331338
/// Notify the driver that the specified operation was inserted. Update the
332339
/// worklist as needed: The operation is enqueued depending on scope and
333340
/// strict mode.
334-
void notifyOperationInserted(Operation *op) override;
341+
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
335342

336343
/// Notify the driver that the specified operation was removed. Update the
337344
/// worklist as needed: The operation and its children are removed from the
@@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
641648
config.listener->notifyBlockRemoved(block);
642649
}
643650

644-
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
651+
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
652+
InsertPoint previous) {
645653
LLVM_DEBUG({
646654
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
647655
<< ")\n";
648656
});
649657
if (config.listener)
650-
config.listener->notifyOperationInserted(op);
658+
config.listener->notifyOperationInserted(op, previous);
651659
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
652660
strictModeFilteredOps.insert(op);
653661
addToWorklist(op);

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
365365
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
366366
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
367367
OpResult newLoopResult = loopLike.getLoopResults()->back();
368-
extractionOp->moveBefore(loopLike);
369-
insertionOp->moveAfter(loopLike);
368+
rewriter.moveOpBefore(extractionOp, loopLike);
369+
rewriter.moveOpAfter(insertionOp, loopLike);
370370
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
371371
insertionOp.getDestinationOperand().get());
372372
extractionOp.getSourceOperand().set(

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
159159
auto ifOp =
160160
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
161161
// True branch.
162-
op->moveBefore(&ifOp.getThenRegion().front(),
163-
ifOp.getThenRegion().front().begin());
162+
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
163+
ifOp.getThenRegion().front().begin());
164164
rewriter.setInsertionPointAfter(op);
165165
if (op->getNumResults() > 0)
166166
rewriter.create<scf::YieldOp>(loc, op->getResults());

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
193193
return failure();
194194
if (!toBeHoisted->hasAttr("eligible"))
195195
return failure();
196-
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
197-
// op is modified.
198-
rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
196+
rewriter.moveOpBefore(toBeHoisted, op);
199197
return success();
200198
}
201199
};

mlir/test/lib/IR/TestClone.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ using namespace mlir;
1515
namespace {
1616

1717
struct DumpNotifications : public OpBuilder::Listener {
18-
void notifyOperationInserted(Operation *op) override {
18+
void notifyOperationInserted(Operation *op,
19+
OpBuilder::InsertPoint previous) override {
1920
llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
2021
}
2122
};

0 commit comments

Comments
 (0)