Skip to content

Commit 4f63252

Browse files
[mlir][transform] Fix crash when op is erased during transform.foreach (#66357)
Fixes a crash when an op, that is mapped to handle that a `transform.foreach` iterates over, was erased (through the `TrackingRewriter`). Erasing an op removes it from all mappings and invalidates iterators. This is already taken care of when an op is iterating over payload ops in its `apply` method, but not when another transform op is erasing a tracked payload op.
1 parent 3c81a0b commit 4f63252

File tree

5 files changed

+58
-3
lines changed

5 files changed

+58
-3
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
11561156
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
11571157
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
11581158

1159+
/// Checks whether the transform op modifies the payload.
1160+
bool doesModifyPayload(transform::TransformOpInterface transform);
1161+
/// Checks whether the transform op reads the payload.
1162+
bool doesReadPayload(transform::TransformOpInterface transform);
1163+
11591164
/// Populates `consumedArguments` with positions of `block` arguments that are
11601165
/// consumed by the operations in the `block`.
11611166
void getConsumedBlockArguments(

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
19041904
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
19051905
}
19061906

1907+
bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1908+
auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1909+
SmallVector<MemoryEffects::EffectInstance> effects;
1910+
iface.getEffects(effects);
1911+
return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1912+
}
1913+
1914+
bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1915+
auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1916+
SmallVector<MemoryEffects::EffectInstance> effects;
1917+
iface.getEffects(effects);
1918+
return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1919+
}
1920+
19071921
void transform::getConsumedBlockArguments(
19081922
Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
19091923
SmallVector<MemoryEffects::EffectInstance> effects;

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
11211121
transform::TransformResults &results,
11221122
transform::TransformState &state) {
11231123
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
1124-
1125-
for (Operation *op : state.getPayloadOps(getTarget())) {
1124+
// Store payload ops in a vector because ops may be removed from the mapping
1125+
// by the TrackingRewriter while the iteration is in progress.
1126+
SmallVector<Operation *> targets =
1127+
llvm::to_vector(state.getPayloadOps(getTarget()));
1128+
for (Operation *op : targets) {
11261129
auto scope = state.make_region_scope(getBody());
11271130
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
11281131
return DiagnosedSilenceableFailure::definiteFailure();
@@ -1152,13 +1155,24 @@ void transform::ForeachOp::getEffects(
11521155
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
11531156
BlockArgument iterVar = getIterationVariable();
11541157
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1158+
11551159
return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
11561160
})) {
11571161
consumesHandle(getTarget(), effects);
11581162
} else {
11591163
onlyReadsHandle(getTarget(), effects);
11601164
}
11611165

1166+
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1167+
return doesModifyPayload(cast<TransformOpInterface>(&op));
1168+
})) {
1169+
modifiesPayload(effects);
1170+
} else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1171+
return doesReadPayload(cast<TransformOpInterface>(&op));
1172+
})) {
1173+
onlyReadsPayload(effects);
1174+
}
1175+
11621176
for (Value result : getResults())
11631177
producesHandle(result, effects);
11641178
}

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,28 @@ transform.with_pdl_patterns {
691691

692692
// -----
693693

694+
// CHECK-LABEL: func @consume_in_foreach()
695+
// CHECK-NEXT: return
696+
func.func @consume_in_foreach() {
697+
%0 = arith.constant 0 : index
698+
%1 = arith.constant 1 : index
699+
%2 = arith.constant 2 : index
700+
%3 = arith.constant 3 : index
701+
return
702+
}
703+
704+
transform.sequence failures(propagate) {
705+
^bb1(%arg1: !transform.any_op):
706+
%f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
707+
transform.foreach %f : !transform.any_op {
708+
^bb2(%arg2: !transform.any_op):
709+
// expected-remark @below {{erasing}}
710+
transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op
711+
}
712+
}
713+
714+
// -----
715+
694716
func.func @bar() {
695717
scf.execute_region {
696718
// expected-remark @below {{transform applied}}

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
390390
transform::TransformResults &results, transform::TransformState &state) {
391391
emitRemark() << getRemark();
392392
for (Operation *op : state.getPayloadOps(getTarget()))
393-
op->erase();
393+
rewriter.eraseOp(op);
394394

395395
if (getFailAfterErase())
396396
return emitSilenceableError() << "silenceable error";

0 commit comments

Comments
 (0)