Skip to content

Commit 6161d33

Browse files
[mlir][IR] Add rewriter API for moving operations
The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`. After an operation was moved, the `notifyOperationInserted` callback is triggered. This may cause listeners such as the greedy pattern rewrite driver to put the op back on the worklist.
1 parent d5c9d40 commit 6161d33

File tree

23 files changed

+144
-47
lines changed

23 files changed

+144
-47
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
490490
LLVM_DUMP_METHOD void dumpFunc();
491491

492492
/// FirOpBuilder hook for creating new operation.
493-
void notifyOperationInserted(mlir::Operation *op) override {
493+
void notifyOperationInserted(mlir::Operation *op,
494+
mlir::OpBuilder::InsertPoint previous) override {
495+
// We only care about newly created operations.
496+
if (!previous.isSet())
497+
return;
494498
setCommonAttributes(op);
495499
}
496500

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
730730
HLFIRListener(fir::FirOpBuilder &builder,
731731
mlir::ConversionPatternRewriter &rewriter)
732732
: builder{builder}, rewriter{rewriter} {}
733-
void notifyOperationInserted(mlir::Operation *op) override {
734-
builder.notifyOperationInserted(op);
735-
rewriter.notifyOperationInserted(op);
733+
void notifyOperationInserted(mlir::Operation *op,
734+
mlir::OpBuilder::InsertPoint previous) override {
735+
builder.notifyOperationInserted(op, previous);
736+
rewriter.notifyOperationInserted(op, previous);
736737
}
737738
virtual void notifyBlockCreated(mlir::Block *block) override {
738739
builder.notifyBlockCreated(block);

mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
6060

6161
// Make sure to allocate at the beginning of the block.
6262
auto *parentBlock = alloc->getBlock();
63-
alloc->moveBefore(&parentBlock->front());
63+
rewriter.moveOpBefore(alloc, &parentBlock->front());
6464

6565
// Make sure to deallocate this alloc at the end of the block. This is fine
6666
// as toy functions have no control flow.
6767
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
68-
dealloc->moveBefore(&parentBlock->back());
68+
rewriter.moveOpBefore(dealloc, &parentBlock->back());
6969
return alloc;
7070
}
7171

mlir/include/mlir/IR/Builders.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ class Builder {
205205
/// automatically inserted at an insertion point. The builder is copyable.
206206
class OpBuilder : public Builder {
207207
public:
208+
class InsertPoint;
208209
struct Listener;
209210

210211
/// Create a builder with the given context.
@@ -285,12 +286,17 @@ class OpBuilder : public Builder {
285286

286287
virtual ~Listener() = default;
287288

288-
/// Notification handler for when an operation is inserted into the builder.
289-
/// `op` is the operation that was inserted.
290-
virtual void notifyOperationInserted(Operation *op) {}
291-
292-
/// Notification handler for when a block is created using the builder.
293-
/// `block` is the block that was created.
289+
/// Notify the listener that the specified operation was inserted.
290+
///
291+
/// * If the operation was moved, then `previous` is the previous location
292+
/// of the op.
293+
/// * If the operation was unlinked before it was inserted, then `previous`
294+
/// is empty.
295+
///
296+
/// Note: Creating an (unlinked) op does not trigger this notification.
297+
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
298+
299+
/// Notify the listener that the specified block was inserted.
294300
virtual void notifyBlockCreated(Block *block) {}
295301

296302
protected:
@@ -517,7 +523,7 @@ class OpBuilder : public Builder {
517523
if (succeeded(tryFold(op, results)))
518524
op->erase();
519525
else if (listener)
520-
listener->notifyOperationInserted(op);
526+
listener->notifyOperationInserted(op, /*previous=*/{});
521527
}
522528

523529
/// Overload to create or fold a single result operation.

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
428428

429429
/// Notify the listener that the specified operation is about to be erased.
430430
/// At this point, the operation has zero uses.
431+
///
432+
/// Note: This notification is not triggered when unlinking an operation.
431433
virtual void notifyOperationRemoved(Operation *op) {}
432434

433435
/// Notify the listener that the pattern failed to match the given
@@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder {
450452
struct ForwardingListener : public RewriterBase::Listener {
451453
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
452454

453-
void notifyOperationInserted(Operation *op) override {
454-
listener->notifyOperationInserted(op);
455+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
456+
listener->notifyOperationInserted(op, previous);
455457
}
456458
void notifyBlockCreated(Block *block) override {
457459
listener->notifyBlockCreated(block);
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
591593
/// block into a new block, and return it.
592594
virtual Block *splitBlock(Block *block, Block::iterator before);
593595

596+
/// Unlink this operation from its current block and insert it right before
597+
/// `existingOp` which may be in the same or another block in the same
598+
/// function.
599+
void moveOpBefore(Operation *op, Operation *existingOp);
600+
601+
/// Unlink this operation from its current block and insert it right before
602+
/// `iterator` in the specified block.
603+
virtual void moveOpBefore(Operation *op, Block *block,
604+
Block::iterator iterator);
605+
606+
/// Unlink this operation from its current block and insert it right after
607+
/// `existingOp` which may be in the same or another block in the same
608+
/// function.
609+
void moveOpAfter(Operation *op, Operation *existingOp);
610+
611+
/// Unlink this operation from its current block and insert it right after
612+
/// `iterator` in the specified block.
613+
virtual void moveOpAfter(Operation *op, Block *block,
614+
Block::iterator iterator);
615+
594616
/// This method is used to notify the rewriter that an in-place operation
595617
/// modification is about to happen. A call to this function *must* be
596618
/// followed by a call to either `finalizeOpModification` or

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
737737
using PatternRewriter::cloneRegionBefore;
738738

739739
/// PatternRewriter hook for inserting a new operation.
740-
void notifyOperationInserted(Operation *op) override;
740+
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
741741

742742
/// PatternRewriter hook for updating the given operation in-place.
743743
/// Note: These methods only track updates to the given operation itself,
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
761761
detail::ConversionPatternRewriterImpl &getImpl();
762762

763763
private:
764+
// Hide unsupported pattern rewriter API.
764765
using OpBuilder::getListener;
765766
using OpBuilder::setListener;
766767

768+
void moveOpBefore(Operation *op, Block *block,
769+
Block::iterator iterator) override;
770+
void moveOpAfter(Operation *op, Block *block,
771+
Block::iterator iterator) override;
772+
767773
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
768774
};
769775

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12061206
if (failed(applyOp->fold(constOperands, foldResults)) ||
12071207
foldResults.empty()) {
12081208
if (OpBuilder::Listener *listener = b.getListener())
1209-
listener->notifyOperationInserted(applyOp);
1209+
listener->notifyOperationInserted(applyOp, /*previous=*/{});
12101210
return applyOp.getResult();
12111211
}
12121212

@@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
12741274
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
12751275
foldResults.empty()) {
12761276
if (OpBuilder::Listener *listener = b.getListener())
1277-
listener->notifyOperationInserted(minMaxOp);
1277+
listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
12781278
return minMaxOp.getResult();
12791279
}
12801280

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
273273
// Insert function into the module symbol table and assign it unique name.
274274
SymbolTable symbolTable(module);
275275
symbolTable.insert(func);
276-
rewriter.getListener()->notifyOperationInserted(func);
276+
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
277277

278278
// Create function entry block.
279279
Block *block =
@@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
489489
// Insert function into the module symbol table and assign it unique name.
490490
SymbolTable symbolTable(module);
491491
symbolTable.insert(func);
492-
rewriter.getListener()->notifyOperationInserted(func);
492+
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
493493

494494
// Create function entry block.
495495
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),

0 commit comments

Comments
 (0)