Skip to content

[mlir][transform] Fix crash when op is erased during transform.foreach #66357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Checks whether the transform op modifies the payload.
bool doesModifyPayload(transform::TransformOpInterface transform);
/// Checks whether the transform op reads the payload.
bool doesReadPayload(transform::TransformOpInterface transform);

/// Populates `consumedArguments` with positions of `block` arguments that are
/// consumed by the operations in the `block`.
void getConsumedBlockArguments(
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1904,6 +1904,20 @@ void transform::onlyReadsPayload(
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}

bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
SmallVector<MemoryEffects::EffectInstance> effects;
iface.getEffects(effects);
return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
}

bool transform::doesReadPayload(transform::TransformOpInterface transform) {
auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
SmallVector<MemoryEffects::EffectInstance> effects;
iface.getEffects(effects);
return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
}

void transform::getConsumedBlockArguments(
Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
SmallVector<MemoryEffects::EffectInstance> effects;
Expand Down
18 changes: 16 additions & 2 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,8 +1121,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});

for (Operation *op : state.getPayloadOps(getTarget())) {
// Store payload ops in a vector because ops may be removed from the mapping
// by the TrackingRewriter while the iteration is in progress.
SmallVector<Operation *> targets =
llvm::to_vector(state.getPayloadOps(getTarget()));
for (Operation *op : targets) {
auto scope = state.make_region_scope(getBody());
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
return DiagnosedSilenceableFailure::definiteFailure();
Expand Down Expand Up @@ -1152,13 +1155,24 @@ void transform::ForeachOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
BlockArgument iterVar = getIterationVariable();
if (any_of(getBody().front().without_terminator(), [&](Operation &op) {

return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
})) {
consumesHandle(getTarget(), effects);
} else {
onlyReadsHandle(getTarget(), effects);
}

if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
return doesModifyPayload(cast<TransformOpInterface>(&op));
})) {
modifiesPayload(effects);
} else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
return doesReadPayload(cast<TransformOpInterface>(&op));
})) {
onlyReadsPayload(effects);
}

for (Value result : getResults())
producesHandle(result, effects);
}
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,28 @@ transform.with_pdl_patterns {

// -----

// CHECK-LABEL: func @consume_in_foreach()
// CHECK-NEXT: return
func.func @consume_in_foreach() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%2 = arith.constant 2 : index
%3 = arith.constant 3 : index
return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%f = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.foreach %f : !transform.any_op {
^bb2(%arg2: !transform.any_op):
// expected-remark @below {{erasing}}
transform.test_emit_remark_and_erase_operand %arg2, "erasing" : !transform.any_op
}
}

// -----

func.func @bar() {
scf.execute_region {
// expected-remark @below {{transform applied}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
emitRemark() << getRemark();
for (Operation *op : state.getPayloadOps(getTarget()))
op->erase();
rewriter.eraseOp(op);

if (getFailAfterErase())
return emitSilenceableError() << "silenceable error";
Expand Down