Skip to content

Commit aca9019

Browse files
[mlir][transform] Check for invalidated iterators on payload IR mappings (#66369)
Add extra error checking (in debug mode) to detect cases where an iterator on "direct" payload IR mappings is invalidated (due to elements being removed). Such errors are hard to debug: they are often non-deterministic; sometimes the program crashes, sometimes it produces wrong results. Even when it crashes, the stack trace often points to completely unrelated code locations. Store a timestamp with each "direct" mapping. The timestamp is increased whenever an operation is performed that invaldiates an iterator on that mapping. A debug iterator is added that checks the timestamp as payload IR is enumerated.
1 parent 66aa9a2 commit aca9019

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ class TransformState {
170170
/// should be emitted when the value is used.
171171
using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
172172

173+
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
174+
/// Debug only: A timestamp is associated with each transform IR value, so
175+
/// that invalid iterator usage can be detected more reliably.
176+
using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
177+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
178+
173179
/// The bidirectional mappings between transform IR values and payload IR
174180
/// operations, and the mapping between transform IR values and parameters.
175181
struct Mappings {
@@ -178,6 +184,11 @@ class TransformState {
178184
ParamMapping params;
179185
ValueMapping values;
180186
ValueMapping reverseValues;
187+
188+
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
189+
TransformIRTimestampMapping timestamps;
190+
void incrementTimestamp(Value value) { ++timestamps[value]; }
191+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
181192
};
182193

183194
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
@@ -207,10 +218,26 @@ class TransformState {
207218
/// not enumerated. This function is helpful for transformations that apply to
208219
/// a particular handle.
209220
auto getPayloadOps(Value value) const {
221+
ArrayRef<Operation *> view = getPayloadOpsView(value);
222+
223+
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
224+
// Memorize the current timestamp and make sure that it has not changed
225+
// when incrementing or dereferencing the iterator returned by this
226+
// function. The timestamp is incremented when the "direct" mapping is
227+
// resized; this would invalidate the iterator returned by this function.
228+
int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
229+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
230+
210231
// When ops are replaced/erased, they are replaced with nullptr (until
211232
// the data structure is compacted). Do not enumerate these ops.
212-
return llvm::make_filter_range(getPayloadOpsView(value),
213-
[](Operation *op) { return op != nullptr; });
233+
return llvm::make_filter_range(view, [=](Operation *op) {
234+
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
235+
bool sameTimestamp =
236+
currentTimestamp == this->getMapping(value).timestamps.lookup(value);
237+
assert(sameTimestamp && "iterator was invalidated during iteration");
238+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
239+
return op != nullptr;
240+
});
214241
}
215242

216243
/// Returns the list of parameters that the given transform IR value

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
310310
for (Operation *op : mappings.direct[opHandle])
311311
dropMappingEntry(mappings.reverse, op, opHandle);
312312
mappings.direct.erase(opHandle);
313+
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
314+
// Payload IR is removed from the mapping. This invalidates the respective
315+
// iterators.
316+
mappings.incrementTimestamp(opHandle);
317+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
313318

314319
for (Value opResult : origOpFlatResults) {
315320
SmallVector<Value> resultHandles;
@@ -336,6 +341,12 @@ void transform::TransformState::forgetValueMapping(
336341
Mappings &localMappings = getMapping(opHandle);
337342
dropMappingEntry(localMappings.direct, opHandle, payloadOp);
338343
dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
344+
345+
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
346+
// Payload IR is removed from the mapping. This invalidates the respective
347+
// iterators.
348+
localMappings.incrementTimestamp(opHandle);
349+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
339350
}
340351
}
341352
}
@@ -774,6 +785,13 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
774785
void transform::TransformState::compactOpHandles() {
775786
for (Value handle : opHandlesToCompact) {
776787
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
788+
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
789+
if (llvm::find(mappings.direct[handle], nullptr) !=
790+
mappings.direct[handle].end())
791+
// Payload IR is removed from the mapping. This invalidates the respective
792+
// iterators.
793+
mappings.incrementTimestamp(handle);
794+
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
777795
llvm::erase_value(mappings.direct[handle], nullptr);
778796
}
779797
opHandlesToCompact.clear();

0 commit comments

Comments
 (0)