Skip to content

Commit 695a5a6

Browse files
[mlir][IR] Trigger notifyOperationRemoved callback for nested ops (#66771)
When cloning an op, the `notifyOperationInserted` callback is triggered for all nested ops. Similarly, the `notifyOperationRemoved` callback should be triggered for all nested ops when removing an op. Listeners may inspect the IR during a `notifyOperationRemoved` callback. Therefore, when multiple ops are removed in a single `RewriterBase::eraseOp` call, the notifications must be triggered in an order in which the ops could have been removed one-by-one: * Op removals must be interleaved with `notifyOperationRemoved` callbacks. A callback is triggered right before the respective op is removed. * Ops are removed post-order and in reverse order. Other traversal orders could delete an op that still has uses. (This is not avoidable in graph regions and with cyclic block graphs.) Differential Revision: Imported from https://reviews.llvm.org/D144193.
1 parent a317afa commit 695a5a6

File tree

7 files changed

+251
-25
lines changed

7 files changed

+251
-25
lines changed

mlir/include/mlir/IR/RegionKindInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
4343
/// not implement the RegionKindInterface.
4444
bool mayHaveSSADominance(Region &region);
4545

46+
/// Return "true" if the given region may be a graph region without SSA
47+
/// dominance. This function returns "true" in case the owner op is an
48+
/// unregistered op. It returns "false" if it is a registered op that does not
49+
/// implement the RegionKindInterface.
50+
bool mayBeGraphRegion(Region &region);
51+
4652
} // namespace mlir
4753

4854
#include "mlir/IR/RegionKindInterface.h.inc"

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,12 +394,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
394394

395395
protected:
396396
void notifyOperationRemoved(Operation *op) override {
397-
// TODO: Walk can be removed when D144193 has landed.
398-
op->walk([&](Operation *op) {
399-
erasedOps.insert(op);
400-
// Erase if present.
401-
toMemrefOps.erase(op);
402-
});
397+
erasedOps.insert(op);
398+
// Erase if present.
399+
toMemrefOps.erase(op);
403400
}
404401

405402
void notifyOperationInserted(Operation *op) override {

mlir/lib/IR/PatternMatch.cpp

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include "mlir/IR/PatternMatch.h"
1010
#include "mlir/IR/IRMapping.h"
11+
#include "mlir/IR/Iterators.h"
12+
#include "mlir/IR/RegionKindInterface.h"
1113

1214
using namespace mlir;
1315

@@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
275277
for (auto it : llvm::zip(op->getResults(), newValues))
276278
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
277279

278-
// Erase the op.
280+
// Erase op and notify listener.
279281
eraseOp(op);
280282
}
281283

@@ -295,17 +297,79 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
295297
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
296298
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
297299

298-
// Erase the old op.
300+
// Erase op and notify listener.
299301
eraseOp(op);
300302
}
301303

302304
/// This method erases an operation that is known to have no uses. The uses of
303305
/// the given operation *must* be known to be dead.
304306
void RewriterBase::eraseOp(Operation *op) {
305307
assert(op->use_empty() && "expected 'op' to have no uses");
306-
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
308+
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
309+
310+
// Fast path: If no listener is attached, the op can be dropped in one go.
311+
if (!rewriteListener) {
312+
op->erase();
313+
return;
314+
}
315+
316+
// Helper function that erases a single op.
317+
auto eraseSingleOp = [&](Operation *op) {
318+
#ifndef NDEBUG
319+
// All nested ops should have been erased already.
320+
assert(
321+
llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
322+
"expected empty regions");
323+
// All users should have been erased already if the op is in a region with
324+
// SSA dominance.
325+
if (!op->use_empty() && op->getParentOp())
326+
assert(mayBeGraphRegion(*op->getParentRegion()) &&
327+
"expected that op has no uses");
328+
#endif // NDEBUG
307329
rewriteListener->notifyOperationRemoved(op);
308-
op->erase();
330+
331+
// Explicitly drop all uses in case the op is in a graph region.
332+
op->dropAllUses();
333+
op->erase();
334+
};
335+
336+
// Nested ops must be erased one-by-one, so that listeners have a consistent
337+
// view of the IR every time a notification is triggered. Users must be
338+
// erased before definitions. I.e., post-order, reverse dominance.
339+
std::function<void(Operation *)> eraseTree = [&](Operation *op) {
340+
// Erase nested ops.
341+
for (Region &r : llvm::reverse(op->getRegions())) {
342+
// Erase all blocks in the right order. Successors should be erased
343+
// before predecessors because successor blocks may use values defined
344+
// in predecessor blocks. A post-order traversal of blocks within a
345+
// region visits successors before predecessors. Repeat the traversal
346+
// until the region is empty. (The block graph could be disconnected.)
347+
while (!r.empty()) {
348+
SmallVector<Block *> erasedBlocks;
349+
for (Block *b : llvm::post_order(&r.front())) {
350+
// Visit ops in reverse order.
351+
for (Operation &op :
352+
llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
353+
eraseTree(&op);
354+
// Do not erase the block immediately. This is not supprted by the
355+
// post_order iterator.
356+
erasedBlocks.push_back(b);
357+
}
358+
for (Block *b : erasedBlocks) {
359+
// Explicitly drop all uses in case there is a cycle in the block
360+
// graph.
361+
for (BlockArgument bbArg : b->getArguments())
362+
bbArg.dropAllUses();
363+
b->dropAllUses();
364+
b->erase();
365+
}
366+
}
367+
}
368+
// Then erase the enclosing op.
369+
eraseSingleOp(op);
370+
};
371+
372+
eraseTree(op);
309373
}
310374

311375
void RewriterBase::eraseBlock(Block *block) {

mlir/lib/IR/RegionKindInterface.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,17 @@ using namespace mlir;
1818
#include "mlir/IR/RegionKindInterface.cpp.inc"
1919

2020
bool mlir::mayHaveSSADominance(Region &region) {
21-
auto regionKindOp =
22-
dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
21+
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
2322
if (!regionKindOp)
2423
return true;
2524
return regionKindOp.hasSSADominance(region.getRegionNumber());
2625
}
26+
27+
bool mlir::mayBeGraphRegion(Region &region) {
28+
if (!region.getParentOp()->isRegistered())
29+
return true;
30+
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
31+
if (!regionKindOp)
32+
return false;
33+
return !regionKindOp.hasSSADominance(region.getRegionNumber());
34+
}

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
421421

422422
// If the operation is trivially dead - remove it.
423423
if (isOpTriviallyDead(op)) {
424-
notifyOperationRemoved(op);
425-
op->erase();
424+
eraseOp(op);
426425
changed = true;
427426

428427
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
567566
config.listener->notifyOperationRemoved(op);
568567

569568
addOperandsToWorklist(op->getOperands());
570-
op->walk([this](Operation *operation) {
571-
worklist.remove(operation);
572-
folder.notifyRemoval(operation);
573-
});
569+
worklist.remove(op);
570+
folder.notifyRemoval(op);
574571

575572
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
576573
strictModeFilteredOps.erase(op);

mlir/test/Transforms/test-strict-pattern-driver.mlir

Lines changed: 153 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
// CHECK-EN-LABEL: func @test_erase
1414
// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
15-
// CHECK-EN: test.arg0
16-
// CHECK-EN: test.arg1
17-
// CHECK-EN-NOT: test.erase_op
15+
// CHECK-EN: "test.arg0"
16+
// CHECK-EN: "test.arg1"
17+
// CHECK-EN-NOT: "test.erase_op"
1818
func.func @test_erase() {
1919
%0 = "test.arg0"() : () -> (i32)
2020
%1 = "test.arg1"() : () -> (i32)
@@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {
5151

5252
// CHECK-EN-LABEL: func @test_replace_with_erase_op
5353
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
54-
// CHECK-EN-NOT: test.replace_with_new_op
55-
// CHECK-EN-NOT: test.erase_op
54+
// CHECK-EN-NOT: "test.replace_with_new_op"
55+
// CHECK-EN-NOT: "test.erase_op"
5656

5757
// CHECK-EX-LABEL: func @test_replace_with_erase_op
5858
// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
59-
// CHECK-EX-NOT: test.replace_with_new_op
60-
// CHECK-EX: test.erase_op
59+
// CHECK-EX-NOT: "test.replace_with_new_op"
60+
// CHECK-EX: "test.erase_op"
6161
func.func @test_replace_with_erase_op() {
6262
"test.replace_with_new_op"() {create_erase_op} : () -> ()
6363
return
@@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
8383
// in turn, replaces the successor with bb3.
8484
"test.implicit_change_op"() [^bb1] : () -> ()
8585
}
86+
87+
// -----
88+
89+
// CHECK-AN: notifyOperationRemoved: test.foo_b
90+
// CHECK-AN: notifyOperationRemoved: test.foo_a
91+
// CHECK-AN: notifyOperationRemoved: test.graph_region
92+
// CHECK-AN: notifyOperationRemoved: test.erase_op
93+
// CHECK-AN-LABEL: func @test_remove_graph_region()
94+
// CHECK-AN-NEXT: return
95+
func.func @test_remove_graph_region() {
96+
"test.erase_op"() ({
97+
test.graph_region {
98+
%0 = "test.foo_a"(%1) : (i1) -> (i1)
99+
%1 = "test.foo_b"(%0) : (i1) -> (i1)
100+
}
101+
}) : () -> ()
102+
return
103+
}
104+
105+
// -----
106+
107+
// CHECK-AN: notifyOperationRemoved: cf.br
108+
// CHECK-AN: notifyOperationRemoved: test.bar
109+
// CHECK-AN: notifyOperationRemoved: cf.br
110+
// CHECK-AN: notifyOperationRemoved: test.foo
111+
// CHECK-AN: notifyOperationRemoved: cf.br
112+
// CHECK-AN: notifyOperationRemoved: test.dummy_op
113+
// CHECK-AN: notifyOperationRemoved: test.erase_op
114+
// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
115+
// CHECK-AN-NEXT: return
116+
func.func @test_remove_cyclic_blocks() {
117+
"test.erase_op"() ({
118+
%x = "test.dummy_op"() : () -> (i1)
119+
cf.br ^bb1(%x: i1)
120+
^bb1(%arg0: i1):
121+
"test.foo"(%x) : (i1) -> ()
122+
cf.br ^bb2(%arg0: i1)
123+
^bb2(%arg1: i1):
124+
"test.bar"(%x) : (i1) -> ()
125+
cf.br ^bb1(%arg1: i1)
126+
}) : () -> ()
127+
return
128+
}
129+
130+
// -----
131+
132+
// CHECK-AN: notifyOperationRemoved: test.dummy_op
133+
// CHECK-AN: notifyOperationRemoved: test.bar
134+
// CHECK-AN: notifyOperationRemoved: test.qux
135+
// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
136+
// CHECK-AN: notifyOperationRemoved: test.nested_dummy
137+
// CHECK-AN: notifyOperationRemoved: cf.br
138+
// CHECK-AN: notifyOperationRemoved: test.foo
139+
// CHECK-AN: notifyOperationRemoved: test.erase_op
140+
// CHECK-AN-LABEL: func @test_remove_dead_blocks()
141+
// CHECK-AN-NEXT: return
142+
func.func @test_remove_dead_blocks() {
143+
"test.erase_op"() ({
144+
"test.dummy_op"() : () -> (i1)
145+
// The following blocks are not reachable. Still, ^bb2 should be deleted
146+
// befire ^bb1.
147+
^bb1(%arg0: i1):
148+
"test.foo"() : () -> ()
149+
cf.br ^bb2(%arg0: i1)
150+
^bb2(%arg1: i1):
151+
"test.nested_dummy"() ({
152+
"test.qux"() : () -> ()
153+
// The following block is unreachable.
154+
^bb3:
155+
"test.qux_unreachable"() : () -> ()
156+
}) : () -> ()
157+
"test.bar"() : () -> ()
158+
}) : () -> ()
159+
return
160+
}
161+
162+
// -----
163+
164+
// test.nested_* must be deleted before test.foo.
165+
// test.bar must be deleted before test.foo.
166+
167+
// CHECK-AN: notifyOperationRemoved: cf.br
168+
// CHECK-AN: notifyOperationRemoved: test.bar
169+
// CHECK-AN: notifyOperationRemoved: cf.br
170+
// CHECK-AN: notifyOperationRemoved: test.nested_b
171+
// CHECK-AN: notifyOperationRemoved: test.nested_a
172+
// CHECK-AN: notifyOperationRemoved: test.nested_d
173+
// CHECK-AN: notifyOperationRemoved: cf.br
174+
// CHECK-AN: notifyOperationRemoved: test.nested_e
175+
// CHECK-AN: notifyOperationRemoved: cf.br
176+
// CHECK-AN: notifyOperationRemoved: test.nested_c
177+
// CHECK-AN: notifyOperationRemoved: test.foo
178+
// CHECK-AN: notifyOperationRemoved: cf.br
179+
// CHECK-AN: notifyOperationRemoved: test.dummy_op
180+
// CHECK-AN: notifyOperationRemoved: test.erase_op
181+
// CHECK-AN-LABEL: func @test_remove_nested_ops()
182+
// CHECK-AN-NEXT: return
183+
func.func @test_remove_nested_ops() {
184+
"test.erase_op"() ({
185+
%x = "test.dummy_op"() : () -> (i1)
186+
cf.br ^bb1(%x: i1)
187+
^bb1(%arg0: i1):
188+
"test.foo"() ({
189+
"test.nested_a"() : () -> ()
190+
"test.nested_b"() : () -> ()
191+
^dead1:
192+
"test.nested_c"() : () -> ()
193+
cf.br ^dead3
194+
^dead2:
195+
"test.nested_d"() : () -> ()
196+
^dead3:
197+
"test.nested_e"() : () -> ()
198+
cf.br ^dead2
199+
}) : () -> ()
200+
cf.br ^bb2(%arg0: i1)
201+
^bb2(%arg1: i1):
202+
"test.bar"(%x) : (i1) -> ()
203+
cf.br ^bb1(%arg1: i1)
204+
}) : () -> ()
205+
return
206+
}
207+
208+
// -----
209+
210+
// CHECK-AN: notifyOperationRemoved: test.qux
211+
// CHECK-AN: notifyOperationRemoved: cf.br
212+
// CHECK-AN: notifyOperationRemoved: test.foo
213+
// CHECK-AN: notifyOperationRemoved: cf.br
214+
// CHECK-AN: notifyOperationRemoved: test.bar
215+
// CHECK-AN: notifyOperationRemoved: cf.cond_br
216+
// CHECK-AN-LABEL: func @test_remove_diamond(
217+
// CHECK-AN-NEXT: return
218+
func.func @test_remove_diamond(%c: i1) {
219+
"test.erase_op"() ({
220+
cf.cond_br %c, ^bb1, ^bb2
221+
^bb1:
222+
"test.foo"() : () -> ()
223+
cf.br ^bb3
224+
^bb2:
225+
"test.bar"() : () -> ()
226+
cf.br ^bb3
227+
^bb3:
228+
"test.qux"() : () -> ()
229+
}) : () -> ()
230+
return
231+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ struct TestPatternDriver
239239
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
240240
};
241241

242+
struct DumpNotifications : public RewriterBase::Listener {
243+
void notifyOperationRemoved(Operation *op) override {
244+
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
245+
}
246+
};
247+
242248
struct TestStrictPatternDriver
243249
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
244250
public:
@@ -275,7 +281,9 @@ struct TestStrictPatternDriver
275281
}
276282
});
277283

284+
DumpNotifications dumpNotifications;
278285
GreedyRewriteConfig config;
286+
config.listener = &dumpNotifications;
279287
if (strictMode == "AnyOp") {
280288
config.strictMode = GreedyRewriteStrictness::AnyOp;
281289
} else if (strictMode == "ExistingAndNewOps") {

0 commit comments

Comments
 (0)