Skip to content

[mlir][transform] Check for invalidated iterators on payload values #66472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

Same as #66369 but for payload values. (#66369 added checks only for payload operations.)

It was necessary to change the signature of getPayloadValues to return an iterator. This is now similar to payload operations.

Fixes an issue in #66369 where the LLVM_ENABLE_ABI_BREAKING_CHECKS check was inverted.

Same as llvm#66369 but for payload values. (llvm#66369 added checks only for payload operations.)

It was necessary to change the signature of `getPayloadValues` to return an iterator. This is now similar to payload operations.

Fixes an issue in llvm#66369 where the `LLVM_ENABLE_ABI_BREAKING_CHECKS` check was inverted.
@llvmbot
Copy link
Member

llvmbot commented Sep 15, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Changes Same as #66369 but for payload values. (#66369 added checks only for payload operations.)

It was necessary to change the signature of getPayloadValues to return an iterator. This is now similar to payload operations.

Fixes an issue in #66369 where the LLVM_ENABLE_ABI_BREAKING_CHECKS check was inverted.

Full diff: https://github.com/llvm/llvm-project/pull/66472.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h (+3-3)
  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+57-13)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+22-20)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+1-3)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-4)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index 1a6afc58fef2704..c8888f294f6ca1d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -95,15 +95,15 @@ class SingleValueMatcherOpTrait
                                     TransformResults &results,
                                     TransformState &state) {
     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
-    ValueRange payload = state.getPayloadValues(operandHandle);
-    if (payload.size() != 1) {
+    auto payload = state.getPayloadValues(operandHandle);
+    if (!llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
              << "SingleValueMatchOpTrait requires the value handle to point to "
                 "a single payload value";
     }
 
     return cast<OpTy>(this->getOperation())
-        .matchValue(payload[0], results, state);
+        .matchValue(*payload.begin(), results, state);
   }
 
   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 86af59142b77d9c..31a93b05cf7a153 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -170,7 +170,7 @@ class TransformState {
   /// should be emitted when the value is used.
   using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// Debug only: A timestamp is associated with each transform IR value, so
   /// that invalid iterator usage can be detected more reliably.
   using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
@@ -185,7 +185,7 @@ class TransformState {
     ValueMapping values;
     ValueMapping reverseValues;
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
     TransformIRTimestampMapping timestamps;
     void incrementTimestamp(Value value) { ++timestamps[value]; }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -220,7 +220,7 @@ class TransformState {
   auto getPayloadOps(Value value) const {
     ArrayRef<Operation *> view = getPayloadOpsView(value);
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
     // Memorize the current timestamp and make sure that it has not changed
     // when incrementing or dereferencing the iterator returned by this
     // function. The timestamp is incremented when the "direct" mapping is
@@ -231,7 +231,7 @@ class TransformState {
     // When ops are replaced/erased, they are replaced with nullptr (until
     // the data structure is compacted). Do not enumerate these ops.
     return llvm::make_filter_range(view, [=](Operation *op) {
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
       bool sameTimestamp =
           currentTimestamp == this->getMapping(value).timestamps.lookup(value);
       assert(sameTimestamp && "iterator was invalidated during iteration");
@@ -244,9 +244,29 @@ class TransformState {
   /// corresponds to.
   ArrayRef<Attribute> getParams(Value value) const;
 
-  /// Returns the list of payload IR values that the given transform IR value
-  /// corresponds to.
-  ArrayRef<Value> getPayloadValues(Value handleValue) const;
+  /// Returns an iterator that enumerates all payload IR values that the given
+  /// transform IR value corresponds to.
+  auto getPayloadValues(Value handleValue) const {
+    ArrayRef<Value> view = getPayloadValuesView(handleValue);
+
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+    // Memorize the current timestamp and make sure that it has not changed
+    // when incrementing or dereferencing the iterator returned by this
+    // function. The timestamp is incremented when the "values" mapping is
+    // resized; this would invalidate the iterator returned by this function.
+    int64_t currentTimestamp =
+        getMapping(handleValue).timestamps.lookup(handleValue);
+    return llvm::make_filter_range(view, [=](Value v) {
+      bool sameTimestamp =
+          currentTimestamp ==
+          this->getMapping(handleValue).timestamps.lookup(handleValue);
+      assert(sameTimestamp && "iterator was invalidated during iteration");
+      return true;
+    });
+#else
+    return llvm::make_range(view.begin(), view.end());
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  }
 
   /// Populates `handles` with all handles pointing to the given Payload IR op.
   /// Returns success if such handles exist, failure otherwise.
@@ -501,12 +521,15 @@ class TransformState {
   LogicalResult updateStateFromResults(const TransformResults &results,
                                        ResultRange opResults);
 
-  /// Returns a list of all ops that the given transform IR value corresponds to
-  /// at the time when this function is called. In case an op was erased, the
-  /// returned list contains nullptr. This function is helpful for
-  /// transformations that apply to a particular handle.
+  /// Returns a list of all ops that the given transform IR value corresponds
+  /// to. In case an op was erased, the returned list contains nullptr. This
+  /// function is helpful for transformations that apply to a particular handle.
   ArrayRef<Operation *> getPayloadOpsView(Value value) const;
 
+  /// Returns a list of payload IR values that the given transform IR value
+  /// corresponds to.
+  ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
+
   /// Sets the payload IR ops associated with the given transform IR value
   /// (handle). A payload op may be associated multiple handles as long as
   /// at most one of them gets consumed by further transformations.
@@ -774,7 +797,8 @@ class TransformResults {
   /// corresponds to the given list of payload IR ops. Each result must be set
   /// by the transformation exactly once in case of transformation succeeding.
   /// The value must have a type implementing TransformHandleTypeInterface.
-  template <typename Range> void set(OpResult value, Range &&ops) {
+  template <typename Range>
+  void set(OpResult value, Range &&ops) {
     int64_t position = value.getResultNumber();
     assert(position < static_cast<int64_t>(operations.size()) &&
            "setting results for a non-existent handle");
@@ -805,7 +829,27 @@ class TransformResults {
   /// set by the transformation exactly once in case of transformation
   /// succeeding. The value must have a type implementing
   /// TransformValueHandleTypeInterface.
-  void setValues(OpResult handle, ValueRange values);
+  template <typename Range>
+  void setValues(OpResult handle, Range &&values) {
+    int64_t position = handle.getResultNumber();
+    assert(position < static_cast<int64_t>(this->values.size()) &&
+           "setting values for a non-existent handle");
+    assert(this->values[position].data() == nullptr && "values already set");
+    assert(operations[position].data() == nullptr &&
+           "another kind of results already set");
+    assert(params[position].data() == nullptr &&
+           "another kind of results already set");
+    this->values.replace(position, std::forward<Range>(values));
+  }
+
+  /// Indicates that the result of the transform IR op at the given position
+  /// corresponds to the given range of payload IR values. Each result must be
+  /// set by the transformation exactly once in case of transformation
+  /// succeeding. The value must have a type implementing
+  /// TransformValueHandleTypeInterface.
+  void setValues(OpResult handle, std::initializer_list<Value> values) {
+    setValues(handle, ArrayRef<Value>(values));
+  }
 
   /// Indicates that the result of the transform IR op at the given position
   /// corresponds to the given range of mapped values. All mapped values are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 7b8bf6fc5d8f5a4..fb021ed76242e90 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -728,7 +728,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
 
   Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
   if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
-    results.setValues(cast<OpResult>(getResult()), result);
+    results.setValues(cast<OpResult>(getResult()), {result});
     return DiagnosedSilenceableFailure::success();
   }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9cac178d3c2b869..fd2cf8816ae2162 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -75,7 +75,7 @@ ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
 }
 
 ArrayRef<Value>
-transform::TransformState::getPayloadValues(Value handleValue) const {
+transform::TransformState::getPayloadValuesView(Value handleValue) const {
   const ValueMapping &mapping = getMapping(handleValue).values;
   auto iter = mapping.find(handleValue);
   assert(iter != mapping.end() && "cannot find mapping for value handle "
@@ -310,7 +310,7 @@ void transform::TransformState::forgetMapping(Value opHandle,
   for (Operation *op : mappings.direct[opHandle])
     dropMappingEntry(mappings.reverse, op, opHandle);
   mappings.direct.erase(opHandle);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
   // Payload IR is removed from the mapping. This invalidates the respective
   // iterators.
   mappings.incrementTimestamp(opHandle);
@@ -322,6 +322,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
     for (Value resultHandle : resultHandles) {
       Mappings &localMappings = getMapping(resultHandle);
       dropMappingEntry(localMappings.values, resultHandle, opResult);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(resultHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
       dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
     }
   }
@@ -333,6 +338,11 @@ void transform::TransformState::forgetValueMapping(
   for (Value payloadValue : mappings.reverseValues[valueHandle])
     dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
   mappings.values.erase(valueHandle);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+  // Payload IR is removed from the mapping. This invalidates the respective
+  // iterators.
+  mappings.incrementTimestamp(valueHandle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   for (Operation *payloadOp : payloadOperations) {
     SmallVector<Value> opHandles;
@@ -342,7 +352,7 @@ void transform::TransformState::forgetValueMapping(
       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
 
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
       // Payload IR is removed from the mapping. This invalidates the respective
       // iterators.
       localMappings.incrementTimestamp(opHandle);
@@ -439,6 +449,11 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
     // between the handles and the IR objects
     if (!replacement) {
       dropMappingEntry(mappings.values, handle, value);
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+      // Payload IR is removed from the mapping. This invalidates the respective
+      // iterators.
+      mappings.incrementTimestamp(handle);
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     } else {
       auto it = mappings.values.find(handle);
       if (it == mappings.values.end())
@@ -647,7 +662,7 @@ void transform::TransformState::recordValueHandleInvalidation(
     OpOperand &valueHandle,
     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
   // Invalidate other handles to the same value.
-  for (Value payloadValue : getPayloadValues(valueHandle.get())) {
+  for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
     SmallVector<Value> otherValueHandles;
     (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
     for (Value otherHandle : otherValueHandles) {
@@ -785,7 +800,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
-#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
     if (llvm::find(mappings.direct[handle], nullptr) !=
         mappings.direct[handle].end())
       // Payload IR is removed from the mapping. This invalidates the respective
@@ -846,7 +861,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
         DiagnosedSilenceableFailure check =
             checkRepeatedConsumptionInOperand<Value>(
-                getPayloadValues(operand.get()), transform,
+                getPayloadValuesView(operand.get()), transform,
                 operand.getOperandNumber());
         if (!check.succeeded()) {
           FULL_LDBG("----FAILED\n");
@@ -912,7 +927,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
       continue;
     }
     if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
-      for (Value payloadValue : getPayloadValues(operand)) {
+      for (Value payloadValue : getPayloadValuesView(operand)) {
         if (llvm::isa<OpResult>(payloadValue)) {
           origAssociatedOps.push_back(payloadValue.getDefiningOp());
           continue;
@@ -1170,19 +1185,6 @@ void transform::TransformResults::setParams(
   this->params.replace(position, params);
 }
 
-void transform::TransformResults::setValues(OpResult handle,
-                                            ValueRange values) {
-  int64_t position = handle.getResultNumber();
-  assert(position < static_cast<int64_t>(this->values.size()) &&
-         "setting values for a non-existent handle");
-  assert(this->values[position].data() == nullptr && "values already set");
-  assert(operations[position].data() == nullptr &&
-         "another kind of results already set");
-  assert(params[position].data() == nullptr &&
-         "another kind of results already set");
-  this->values.replace(position, values);
-}
-
 void transform::TransformResults::setMappedValues(
     OpResult handle, ArrayRef<MappedValue> values) {
   DiagnosedSilenceableFailure diag = dispatchMappedValues(
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index de3cd1b28e435bc..cd4f628f1459ab7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1378,9 +1378,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
   SmallVector<Attribute> params;
-  ArrayRef<Value> values = state.getPayloadValues(getValue());
-  params.reserve(values.size());
-  for (Value value : values) {
+  for (Value value : state.getPayloadValues(getValue())) {
     Type type = value.getType();
     if (getElemental()) {
       if (auto shaped = dyn_cast<ShapedType>(type)) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 59f045de3246f6b..e8c25aca237251a 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -136,7 +136,7 @@ DiagnosedSilenceableFailure
 mlir::test::TestProduceValueHandleToSelfOperand::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  results.setValues(llvm::cast<OpResult>(getOut()), getIn());
+  results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -265,8 +265,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
 DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  ArrayRef<Value> values = state.getPayloadValues(getIn());
-  for (Value value : values) {
+  for (Value value : state.getPayloadValues(getIn())) {
     std::string note;
     llvm::raw_string_ostream os(note);
     if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
@@ -712,7 +711,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
 DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
-  results.setValues(llvm::cast<OpResult>(getOut()), Value());
+  results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
   return DiagnosedSilenceableFailure::success();
 }
 

@matthias-springer matthias-springer merged commit 085075a into llvm:main Sep 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants