Skip to content

Commit 3ed98cb

Browse files
[mlir][IR] Change notifyBlockCreated to notifyBlockInserted (#79472)
This change makes the callback consistent with `notifyOperationInserted`: both now notify about IR insertion, not IR creation. See also #78988. This change also simplifies the dialect conversion: it is no longer necessary to override the `inlineRegionBefore` method. All information that is necessary for rollback is provided with the `notifyBlockInserted` callback.
1 parent fb8eb42 commit 3ed98cb

File tree

8 files changed

+64
-61
lines changed

8 files changed

+64
-61
lines changed

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -735,9 +735,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
735735
builder.notifyOperationInserted(op, previous);
736736
rewriter.notifyOperationInserted(op, previous);
737737
}
738-
virtual void notifyBlockCreated(mlir::Block *block) override {
739-
builder.notifyBlockCreated(block);
740-
rewriter.notifyBlockCreated(block);
738+
virtual void notifyBlockInserted(mlir::Block *block, mlir::Region *previous,
739+
mlir::Region::iterator previousIt) override {
740+
builder.notifyBlockInserted(block, previous, previousIt);
741+
rewriter.notifyBlockInserted(block, previous, previousIt);
741742
}
742743
fir::FirOpBuilder &builder;
743744
mlir::ConversionPatternRewriter &rewriter;

mlir/include/mlir/IR/Builders.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,15 @@ class OpBuilder : public Builder {
297297
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
298298

299299
/// Notify the listener that the specified block was inserted.
300-
virtual void notifyBlockCreated(Block *block) {}
300+
///
301+
/// * If the block was moved, then `previous` and `previousIt` are the
302+
/// previous location of the block.
303+
/// * If the block was unlinked before it was inserted, then `previous`
304+
/// is "nullptr".
305+
///
306+
/// Note: Creating an (unlinked) block does not trigger this notification.
307+
virtual void notifyBlockInserted(Block *block, Region *previous,
308+
Region::iterator previousIt) {}
301309

302310
protected:
303311
Listener(Kind kind) : ListenerBase(kind) {}

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,9 @@ class RewriterBase : public OpBuilder {
455455
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
456456
listener->notifyOperationInserted(op, previous);
457457
}
458-
void notifyBlockCreated(Block *block) override {
459-
listener->notifyBlockCreated(block);
458+
void notifyBlockInserted(Block *block, Region *previous,
459+
Region::iterator previousIt) override {
460+
listener->notifyBlockInserted(block, previous, previousIt);
460461
}
461462
void notifyBlockRemoved(Block *block) override {
462463
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
@@ -495,8 +496,8 @@ class RewriterBase : public OpBuilder {
495496
/// another region "parent". The two regions must be different. The caller
496497
/// is responsible for creating or updating the operation transferring flow
497498
/// of control to the region and passing it the correct block arguments.
498-
virtual void inlineRegionBefore(Region &region, Region &parent,
499-
Region::iterator before);
499+
void inlineRegionBefore(Region &region, Region &parent,
500+
Region::iterator before);
500501
void inlineRegionBefore(Region &region, Block *before);
501502

502503
/// Clone the blocks that belong to "region" before the given position in

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,8 @@ class ConversionPatternRewriter final : public PatternRewriter,
713713
void eraseBlock(Block *block) override;
714714

715715
/// PatternRewriter hook creating a new block.
716-
void notifyBlockCreated(Block *block) override;
716+
void notifyBlockInserted(Block *block, Region *previous,
717+
Region::iterator previousIt) override;
717718

718719
/// PatternRewriter hook for splitting a block into two parts.
719720
Block *splitBlock(Block *block, Block::iterator before) override;
@@ -723,11 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
723724
ValueRange argValues = std::nullopt) override;
724725
using PatternRewriter::inlineBlockBefore;
725726

726-
/// PatternRewriter hook for moving blocks out of a region.
727-
void inlineRegionBefore(Region &region, Region &parent,
728-
Region::iterator before) override;
729-
using PatternRewriter::inlineRegionBefore;
730-
731727
/// PatternRewriter hook for cloning blocks of one region into another. The
732728
/// given region to clone *must* not have been modified as part of conversion
733729
/// yet, i.e. it must be within an operation that is either in the process of

mlir/lib/IR/Builders.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
429429
setInsertionPointToEnd(b);
430430

431431
if (listener)
432-
listener->notifyBlockCreated(b);
432+
listener->notifyBlockInserted(b, /*previous=*/nullptr, /*previousIt=*/{});
433433
return b;
434434
}
435435

mlir/lib/IR/PatternMatch.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,18 @@ Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
343343
/// region and pass it the correct block arguments.
344344
void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
345345
Region::iterator before) {
346-
parent.getBlocks().splice(before, region.getBlocks());
346+
// Fast path: If no listener is attached, move all blocks at once.
347+
if (!listener) {
348+
parent.getBlocks().splice(before, region.getBlocks());
349+
return;
350+
}
351+
352+
// Move blocks from the beginning of the region one-by-one.
353+
while (!region.empty()) {
354+
Block *block = &region.front();
355+
parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
356+
listener->notifyBlockInserted(block, &region, region.begin());
357+
}
347358
}
348359
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
349360
inlineRegionBefore(region, *before->getParent(), before->getIterator());

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 25 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,10 @@ enum class BlockActionKind {
250250
};
251251

252252
/// Original position of the given block in its parent region. During undo
253-
/// actions, the block needs to be placed after `insertAfterBlock`.
253+
/// actions, the block needs to be placed before `insertBeforeBlock`.
254254
struct BlockPosition {
255255
Region *region;
256-
Block *insertAfterBlock;
256+
Block *insertBeforeBlock;
257257
};
258258

259259
/// Information needed to undo inlining actions.
@@ -910,7 +910,8 @@ struct ConversionPatternRewriterImpl {
910910
void notifyBlockIsBeingErased(Block *block);
911911

912912
/// Notifies that a block was created.
913-
void notifyCreatedBlock(Block *block);
913+
void notifyInsertedBlock(Block *block, Region *previous,
914+
Region::iterator previousIt);
914915

915916
/// Notifies that a block was split.
916917
void notifySplitBlock(Block *block, Block *continuation);
@@ -919,10 +920,6 @@ struct ConversionPatternRewriterImpl {
919920
void notifyBlockBeingInlined(Block *block, Block *srcBlock,
920921
Block::iterator before);
921922

922-
/// Notifies that the blocks of a region are about to be moved.
923-
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
924-
Region::iterator before);
925-
926923
/// Notifies that a pattern match failed for the given reason.
927924
LogicalResult
928925
notifyMatchFailure(Location loc,
@@ -1173,10 +1170,9 @@ void ConversionPatternRewriterImpl::undoBlockActions(
11731170
// Put the block (owned by action) back into its original position.
11741171
case BlockActionKind::Erase: {
11751172
auto &blockList = action.originalPosition.region->getBlocks();
1176-
Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1177-
blockList.insert((insertAfterBlock
1178-
? std::next(Region::iterator(insertAfterBlock))
1179-
: blockList.begin()),
1173+
Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
1174+
blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
1175+
: blockList.end()),
11801176
action.block);
11811177
break;
11821178
}
@@ -1196,10 +1192,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
11961192
// Move the block back to its original position.
11971193
case BlockActionKind::Move: {
11981194
Region *originalRegion = action.originalPosition.region;
1199-
Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1195+
Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
12001196
originalRegion->getBlocks().splice(
1201-
(insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
1202-
: originalRegion->end()),
1197+
(insertBeforeBlock ? Region::iterator(insertBeforeBlock)
1198+
: originalRegion->end()),
12031199
action.block->getParent()->getBlocks(), action.block);
12041200
break;
12051201
}
@@ -1398,12 +1394,19 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13981394

13991395
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
14001396
Region *region = block->getParent();
1401-
Block *origPrevBlock = block->getPrevNode();
1402-
blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1397+
Block *origNextBlock = block->getNextNode();
1398+
blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
14031399
}
14041400

1405-
void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
1406-
blockActions.push_back(BlockAction::getCreate(block));
1401+
void ConversionPatternRewriterImpl::notifyInsertedBlock(
1402+
Block *block, Region *previous, Region::iterator previousIt) {
1403+
if (!previous) {
1404+
// This is a newly created block.
1405+
blockActions.push_back(BlockAction::getCreate(block));
1406+
return;
1407+
}
1408+
Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
1409+
blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
14071410
}
14081411

14091412
void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
@@ -1416,19 +1419,6 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
14161419
blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
14171420
}
14181421

1419-
void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
1420-
Region &region, Region &parent, Region::iterator before) {
1421-
if (region.empty())
1422-
return;
1423-
Block *laterBlock = &region.back();
1424-
for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1425-
blockActions.push_back(
1426-
BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
1427-
laterBlock = &earlierBlock;
1428-
}
1429-
blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
1430-
}
1431-
14321422
LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
14331423
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
14341424
LLVM_DEBUG({
@@ -1551,8 +1541,9 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
15511541
results);
15521542
}
15531543

1554-
void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
1555-
impl->notifyCreatedBlock(block);
1544+
void ConversionPatternRewriter::notifyBlockInserted(
1545+
Block *block, Region *previous, Region::iterator previousIt) {
1546+
impl->notifyInsertedBlock(block, previous, previousIt);
15561547
}
15571548

15581549
Block *ConversionPatternRewriter::splitBlock(Block *block,
@@ -1582,13 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
15821573
eraseBlock(source);
15831574
}
15841575

1585-
void ConversionPatternRewriter::inlineRegionBefore(Region &region,
1586-
Region &parent,
1587-
Region::iterator before) {
1588-
impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1589-
PatternRewriter::inlineRegionBefore(region, parent, before);
1590-
}
1591-
15921576
void ConversionPatternRewriter::cloneRegionBefore(Region &region,
15931577
Region &parent,
15941578
Region::iterator before,
@@ -1600,7 +1584,7 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
16001584

16011585
for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
16021586
Block *cloned = mapping.lookup(&b);
1603-
impl->notifyCreatedBlock(cloned);
1587+
impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
16041588
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
16051589
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
16061590
}

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
377377
/// simplifications.
378378
void addOperandsToWorklist(ValueRange operands);
379379

380-
/// Notify the driver that the given block was created.
381-
void notifyBlockCreated(Block *block) override;
380+
/// Notify the driver that the given block was inserted.
381+
void notifyBlockInserted(Block *block, Region *previous,
382+
Region::iterator previousIt) override;
382383

383384
/// Notify the driver that the given block is about to be removed.
384385
void notifyBlockRemoved(Block *block) override;
@@ -638,9 +639,10 @@ void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
638639
worklist.push(op);
639640
}
640641

641-
void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
642+
void GreedyPatternRewriteDriver::notifyBlockInserted(
643+
Block *block, Region *previous, Region::iterator previousIt) {
642644
if (config.listener)
643-
config.listener->notifyBlockCreated(block);
645+
config.listener->notifyBlockInserted(block, previous, previousIt);
644646
}
645647

646648
void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {

0 commit comments

Comments
 (0)