Skip to content

[mlir][IR] Add rewriter API for moving operations #78988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
LLVM_DUMP_METHOD void dumpFunc();

/// FirOpBuilder hook for creating new operation.
void notifyOperationInserted(mlir::Operation *op) override {
void notifyOperationInserted(mlir::Operation *op,
mlir::OpBuilder::InsertPoint previous) override {
// We only care about newly created operations.
if (previous.isSet())
return;
setCommonAttributes(op);
}

Expand Down
7 changes: 4 additions & 3 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
HLFIRListener(fir::FirOpBuilder &builder,
mlir::ConversionPatternRewriter &rewriter)
: builder{builder}, rewriter{rewriter} {}
void notifyOperationInserted(mlir::Operation *op) override {
builder.notifyOperationInserted(op);
rewriter.notifyOperationInserted(op);
void notifyOperationInserted(mlir::Operation *op,
mlir::OpBuilder::InsertPoint previous) override {
builder.notifyOperationInserted(op, previous);
rewriter.notifyOperationInserted(op, previous);
}
virtual void notifyBlockCreated(mlir::Block *block) override {
builder.notifyBlockCreated(block);
Expand Down
20 changes: 13 additions & 7 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class Builder {
/// automatically inserted at an insertion point. The builder is copyable.
class OpBuilder : public Builder {
public:
class InsertPoint;
struct Listener;

/// Create a builder with the given context.
Expand Down Expand Up @@ -285,12 +286,17 @@ class OpBuilder : public Builder {

virtual ~Listener() = default;

/// Notification handler for when an operation is inserted into the builder.
/// `op` is the operation that was inserted.
virtual void notifyOperationInserted(Operation *op) {}

/// Notification handler for when a block is created using the builder.
/// `block` is the block that was created.
/// Notify the listener that the specified operation was inserted.
///
/// * If the operation was moved, then `previous` is the previous location
/// of the op.
/// * If the operation was unlinked before it was inserted, then `previous`
/// is empty.
///
/// Note: Creating an (unlinked) op does not trigger this notification.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is similar to what was proposed here (IR listeners), but the callback is fired after the IR was modified.

  // This method is called when an operation is inserted into a block. The oldBlock is nullptr is the operation wasn't previously in a block.
  virtual void notifyOpInserted(Operation *op, Block *oldBlock,
                                Block *newBlock) {}


/// Notify the listener that the specified block was inserted.
virtual void notifyBlockCreated(Block *block) {}

protected:
Expand Down Expand Up @@ -517,7 +523,7 @@ class OpBuilder : public Builder {
if (succeeded(tryFold(op, results)))
op->erase();
else if (listener)
listener->notifyOperationInserted(op);
listener->notifyOperationInserted(op, /*previous=*/{});
}

/// Overload to create or fold a single result operation.
Expand Down
26 changes: 24 additions & 2 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {

/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
///
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationRemoved(Operation *op) {}

/// Notify the listener that the pattern failed to match the given
Expand All @@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder {
struct ForwardingListener : public RewriterBase::Listener {
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}

void notifyOperationInserted(Operation *op) override {
listener->notifyOperationInserted(op);
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
listener->notifyOperationInserted(op, previous);
}
void notifyBlockCreated(Block *block) override {
listener->notifyBlockCreated(block);
Expand Down Expand Up @@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);

/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
void moveOpBefore(Operation *op, Operation *existingOp);

/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
virtual void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator);

/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
/// function.
void moveOpAfter(Operation *op, Operation *existingOp);

/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);

/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
using PatternRewriter::cloneRegionBefore;

/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;
void notifyOperationInserted(Operation *op, InsertPoint previous) override;

/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
Expand All @@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
detail::ConversionPatternRewriterImpl &getImpl();

private:
// Hide unsupported pattern rewriter API.
using OpBuilder::getListener;
using OpBuilder::setListener;

void moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) override;
void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) override;

std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
if (failed(applyOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
listener->notifyOperationInserted(applyOp);
listener->notifyOperationInserted(applyOp, /*previous=*/{});
return applyOp.getResult();
}

Expand Down Expand Up @@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
if (failed(minMaxOp->fold(constOperands, foldResults)) ||
foldResults.empty()) {
if (OpBuilder::Listener *listener = b.getListener())
listener->notifyOperationInserted(minMaxOp);
listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
return minMaxOp.getResult();
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func);
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

// Create function entry block.
Block *block =
Expand Down Expand Up @@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
// Insert function into the module symbol table and assign it unique name.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func);
rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

// Create function entry block.
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
toMemrefOps.erase(op);
}

void notifyOperationInserted(Operation *op) override {
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
// We only care about newly created ops.
if (previous.isSet())
return;

erasedOps.erase(op);

// Gather statistics about allocs.
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,12 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
}

private:
void notifyOperationInserted(Operation *op) override {
ForwardingListener::notifyOperationInserted(op);
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
ForwardingListener::notifyOperationInserted(op, previous);
// We only care about newly created ops.
if (previous.isSet())
return;
auto inserted = newOps.insert(op);
(void)inserted;
assert(inserted.second && "expected newly created op");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {

// Inline for-loop body operations into 'after' region.
for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
arg.moveBefore(afterBlock, afterBlock->end());
rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());

// Add incremented IV to yield operations
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
for (Operation *user : srcBuffer->getUsers()) {
if (hasEffect<MemoryEffects::Free>(user)) {
if (user->getBlock() == parallelCombiningParent->getBlock())
user->moveBefore(user->getBlock()->getTerminator());
rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
break;
}
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
block->getOperations().insert(insertPoint, op);

if (listener)
listener->notifyOperationInserted(op);
listener->notifyOperationInserted(op, /*previous=*/{});
return op;
}

Expand Down Expand Up @@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp);
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
Expand Down
28 changes: 28 additions & 0 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}

void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}

void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me that this is enough to implement a "moveOpBefore" safely.

See: https://discourse.llvm.org/t/what-is-the-extent-of-the-changes-that-can-be-done-in-a-patternrewriter/76343/2

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to the fact that "moving ops" is not supported by the dialect conversion? The notifyOperationInserted callback is triggered after moving the op. The current location can be queried from the op itself. The previous location is passed as a parameter. That should be enough information to implement the rollback mechanism in the future.

"Moving an op" is a form of "inserting an op". Until now, we used notifyOperationInserted only for newly created ops. But in both cases we are inserting an op, only the previous location is different (moving an op: has a previous location, inserting a newly created op: was previously unlinked). (The callback is called notifyOperationInserted, not notifyOperationCreated.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me what operations should be considered "changed" when we move an op.

For example sinking:

A;
for (...) {
  if (...) {
    ...
  }
}

to

for (...) {
  if (...) {
    A 
    ...
  }
}

Is the "if" modified? The "for"?
Is notifyOperationInserted really meant to handle arbitrary moves and the handler must handle all this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually have the same issue with other kinds of IR changes:

  • When erasing an operation, should we fire notifyOperationModified for the parent op? And for the parent's parent op? Etc.
  • When inserting a newly created operation, should we fire notifyOperationModified for the parent op into which we are inserting? And for the parent's parent op? Etc.
  • Same for erasing a block.

My conclusion was that we should not trigger a notifyOperationModified for things that have a separate notification. Otherwise, to be consistent, we would have to call notifyOperationModified for pretty much every IR change. That would potentially make it more difficult for listeners, because they would get duplicate notifications about the same IR change. (E.g., when inserting an op there would be two notifications.)

I think we should trigger notifyOperationModified only for "attribute changed", "property changed", "result type changed", "operand changed". Maybe also for "region entry block argument changed". And have separate callbacks for everything else.

Another thing that we could consider is giving notifyOperationInsertion, notificationOperationRemoved, notifyBlockCreated, etc. a default implementation that calls notifyOperationModified. Listeners could then decide what kind of granularity of notifications they would like to receive. (We already do something similar for notifyOperationReplaced(Operation *, Operation *).)

Is notifyOperationInserted really meant to handle arbitrary moves and the handler must handle all this?

I'd say that notifyOperationInserted should be called for all op insertions. Whether the inserted op is an already existing op or a newly created op is irrelevant. (Note the function is called notifyOperationInserted not notifyOperationCreated.)

But it raises the question whether notifyOperationRemoved should also be triggered. In my current implementation it is not, because the documentation of the callback says that it is triggered for "op erasure", not "op removal from a block". (I think the callback should be renamed to notifyOperationErased.)

/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
virtual void notifyOperationRemoved(Operation *op) {}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK!

}

void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
}

void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
Block::iterator currentIterator = op->getIterator();
op->moveAfter(block, iterator);
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, currentIterator));
}
18 changes: 16 additions & 2 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
Block *cloned = mapping.lookup(&b);
impl->notifyCreatedBlock(cloned);
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) { notifyOperationInserted(op); });
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
}
}

void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
impl->logger.startLine()
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
Expand Down Expand Up @@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
return impl->notifyMatchFailure(loc, reasonCallback);
}

void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}

void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
"moving single ops is not supported in a dialect conversion");
}

detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
Expand Down
18 changes: 13 additions & 5 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,16 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
}
}

void notifyOperationInserted(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op);
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
// Invalidate the finger print of the op that owns the block into which the
// op was inserted into.
invalidateFingerPrint(op->getParentOp());

// Also invalidate the finger print of the op that owns the block from which
// the op was moved from. (Only applicable if the op was moved.)
if (previous.isSet())
invalidateFingerPrint(previous.getBlock()->getParentOp());
}

void notifyOperationModified(Operation *op) override {
Expand Down Expand Up @@ -331,7 +338,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
void notifyOperationInserted(Operation *op) override;
void notifyOperationInserted(Operation *op, InsertPoint previous) override;

/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
Expand Down Expand Up @@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
config.listener->notifyBlockRemoved(block);
}

void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.listener)
config.listener->notifyOperationInserted(op);
config.listener->notifyOperationInserted(op, previous);
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
OpResult newLoopResult = loopLike.getLoopResults()->back();
extractionOp->moveBefore(loopLike);
insertionOp->moveAfter(loopLike);
rewriter.moveOpBefore(extractionOp, loopLike);
rewriter.moveOpAfter(insertionOp, loopLike);
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
insertionOp.getDestinationOperand().get());
extractionOp.getSourceOperand().set(
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
auto ifOp =
rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
// True branch.
op->moveBefore(&ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
ifOp.getThenRegion().front().begin());
rewriter.setInsertionPointAfter(op);
if (op->getNumResults() > 0)
rewriter.create<scf::YieldOp>(loc, op->getResults());
Expand Down
4 changes: 1 addition & 3 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
return failure();
if (!toBeHoisted->hasAttr("eligible"))
return failure();
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a workaround around the fact that there were no notifications for moveBefore. This used to trigger a failed "expensive pattern check" (IR changed but rewriter was not notified).

// op is modified.
rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
rewriter.moveOpBefore(toBeHoisted, op);
return success();
}
};
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/lib/IR/TestClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ using namespace mlir;
namespace {

struct DumpNotifications : public OpBuilder::Listener {
void notifyOperationInserted(Operation *op) override {
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
}
};
Expand Down
Loading