Skip to content

[mlir][transform] Check for invalidated iterators on payload IR mappings #66369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ class TransformState {
/// should be emitted when the value is used.
using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Debug only: A timestamp is associated with each transform IR value, so
/// that invalid iterator usage can be detected more reliably.
using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

/// The bidirectional mappings between transform IR values and payload IR
/// operations, and the mapping between transform IR values and parameters.
struct Mappings {
Expand All @@ -178,6 +184,11 @@ class TransformState {
ParamMapping params;
ValueMapping values;
ValueMapping reverseValues;

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
TransformIRTimestampMapping timestamps;
void incrementTimestamp(Value value) { ++timestamps[value]; }
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};

friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
Expand Down Expand Up @@ -207,10 +218,26 @@ class TransformState {
/// not enumerated. This function is helpful for transformations that apply to
/// a particular handle.
auto getPayloadOps(Value value) const {
ArrayRef<Operation *> view = getPayloadOpsView(value);

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Memorize the current timestamp and make sure that it has not changed
// when incrementing or dereferencing the iterator returned by this
// function. The timestamp is incremented when the "direct" mapping is
// resized; this would invalidate the iterator returned by this function.
int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

// When ops are replaced/erased, they are replaced with nullptr (until
// the data structure is compacted). Do not enumerate these ops.
return llvm::make_filter_range(getPayloadOpsView(value),
[](Operation *op) { return op != nullptr; });
return llvm::make_filter_range(view, [=](Operation *op) {
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
bool sameTimestamp =
currentTimestamp == this->getMapping(value).timestamps.lookup(value);
assert(sameTimestamp && "iterator was invalidated during iteration");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return op != nullptr;
});
}

/// Returns the list of parameters that the given transform IR value
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Operation *op : mappings.direct[opHandle])
dropMappingEntry(mappings.reverse, op, opHandle);
mappings.direct.erase(opHandle);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

for (Value opResult : origOpFlatResults) {
SmallVector<Value> resultHandles;
Expand All @@ -336,6 +341,12 @@ void transform::TransformState::forgetValueMapping(
Mappings &localMappings = getMapping(opHandle);
dropMappingEntry(localMappings.direct, opHandle, payloadOp);
dropMappingEntry(localMappings.reverse, payloadOp, opHandle);

#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
localMappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
}
}
Expand Down Expand Up @@ -774,6 +785,13 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
void transform::TransformState::compactOpHandles() {
for (Value handle : opHandlesToCompact) {
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
if (llvm::find(mappings.direct[handle], nullptr) !=
mappings.direct[handle].end())
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(handle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
llvm::erase_value(mappings.direct[handle], nullptr);
}
opHandlesToCompact.clear();
Expand Down