
[mlir][tensor] add tensor insert/extract op folders #142458


Merged
merged 1 commit into llvm:main on Jun 3, 2025

Conversation

Contributor

@asraa asraa commented Jun 2, 2025

Adds a few canonicalizers, folders, and rewrite patterns to tensor ops:

  • tensor.insert folder: insert into a constant is replaced with a new constant
  • tensor.extract folder: an extract from a tensor that was produced by an insert at the same indices folds to the inserted value
  • a rewrite pattern that replaces an extract of a collapse_shape with an extract of the source tensor (requires static source dimensions); a usage sketch follows below
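
The collapse_shape rewrite is exposed through a standalone populate function rather than as a canonicalization pattern. A minimal usage sketch for driving it from a downstream pass (hypothetical helper, not part of this patch):

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper mirroring the test pass: collect the new
// collapse_shape -> extract rewrite and apply it greedily to `root`.
static void applyExtractFromCollapsePatterns(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::tensor::populateFoldCollapseExtractPatterns(patterns);
  // Best-effort application; the convergence result is intentionally ignored.
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}
```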

Signed-off-by: Asra Ali <asraa@google.com>
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (asraa)

Changes

Adds a few canonicalizers, folders, and rewrite patterns to tensor ops:

  • tensor.insert folder: insert into a constant is replaced with a new constant
  • tensor.extract folder: an extract from a tensor that was produced by an insert at the same indices folds to the inserted value
  • a rewrite pattern that replaces an extract of a collapse_shape with an extract of the source tensor (requires static source dimensions)

Full diff: https://github.com/llvm/llvm-project/pull/142458.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (+4)
  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+165)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+26-3)
  • (added) mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir (+31)
  • (modified) mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp (+13)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index eb550bb469b9f..e8e1342ef36fd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
           return false;
         });
 
+/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
+/// source tensor.
+void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..c0885a3763827 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 30ca20fc0d883..f2a7220b4bedc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
+#include <vector>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Canonicalizes the pattern of the form
+///
+/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
+/// tensor<12xf64>
+/// %extracted_element = tensor.extract %val[%c10] :
+/// tensor<12xf64>
+///
+/// to
+///
+/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
+struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const final {
+    auto collapseOp =
+        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp)
+      return failure();
+    if (!collapseOp.getSrcType().hasStaticShape())
+      return failure();
+
+    auto sourceSizes = collapseOp.getSrcType().getShape();
+
+    SmallVector<Value> indices(extractOp.getIndices().begin(),
+                               extractOp.getIndices().end());
+    SmallVector<Value> sourceIndices;
+    for (auto [index, group] :
+         llvm::zip(indices, collapseOp.getReassociationIndices())) {
+      assert(!group.empty() && "association indices groups cannot be empty");
+      auto groupSize = group.size();
+
+      if (groupSize == 1) {
+        sourceIndices.push_back(index);
+        continue;
+      }
+
+      SmallVector<int64_t> basis =
+          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+          extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
+      llvm::append_range(sourceIndices, delinearize.getResults());
+    }
+    if (collapseOp.getReassociationIndices().empty()) {
+      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+      int64_t srcRank =
+          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(), zeroAffineMap,
+          ArrayRef<OpFoldResult>{});
+      for (int64_t i = 0; i < srcRank; i++) {
+        sourceIndices.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+        extractOp, collapseOp.getSrc(), sourceIndices);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
   return success();
 }
 
+/// If we have an ExtractOp consuming an InsertOp with the same
+/// indices, we can return the InsertOp's scalar directly.
+// TODO: This only checks the immediate producer; extend to go up the
+// insert/extract chain if the slices are disjoint.
+static Value foldExtractAfterInsert(ExtractOp extractOp) {
+  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
+
+  auto isSame = [](Value a, Value b) {
+    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
+  };
+  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
+      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
+    return insertOp.getScalar();
+
+  return {};
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   if (Attribute tensor = adaptor.getTensor()) {
     // If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
       return elementsAttr.getValues<Attribute>()[indices];
   }
 
+  if (Value result = foldExtractAfterInsert(*this))
+    return result;
+
   return {};
 }
 
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExtractFromTensorCast>(context);
 }
 
+void mlir::tensor::populateFoldCollapseExtractPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 // InsertOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Pattern to fold an insert op of a constant destination and scalar to a new
+/// constant.
+///
+/// Example:
+/// ```
+///   %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+///   %c0 = arith.constant 0 : index
+///   %c4_f32 = arith.constant 4.0 : f32
+///   %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///   %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
+/// ```
+class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // Requires a ranked tensor type.
+    auto destType =
+        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
+    if (!destType)
+      return failure();
+
+    // Pattern requires constant indices
+    SmallVector<uint64_t, 8> indices;
+    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
+      auto indiceAttr = dyn_cast<Attribute>(indice);
+      if (!indiceAttr)
+        return failure();
+      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
+    }
+
+    // Requires a constant scalar to insert
+    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
+    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
+    if (!scalarAttr)
+      return failure();
+
+    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
+            insertOp.getDest().getDefiningOp())) {
+      if (auto sourceAttr =
+              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
+        // Update the attribute at the inserted index.
+        auto sourceValues = sourceAttr.getValues<Attribute>();
+        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
+        std::vector<Attribute> updatedValues;
+        updatedValues.reserve(sourceAttr.getNumElements());
+        for (auto i = 0; i < sourceAttr.getNumElements(); ++i) {
+          updatedValues.push_back(i == flattenedIndex ? scalarAttr
+                                                      : sourceValues[i]);
+        }
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            insertOp, sourceAttr.getType(),
+            DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
+                                   updatedValues));
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+
+} // namespace
+
 void InsertOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<InsertOpConstantFold>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index cdcd7f305d2d9..0abec7e01d184 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x
 // -----
 
 // CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
   %const_0 = arith.constant 0 : index
   %const_1 = arith.constant 1 : index
   %const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
   %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
 
-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
-  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+  // Fold an extract after an insert.
+  // CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
+  %inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
+  %ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>
+
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
 }
 
 // -----
@@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   return %ins_1 : tensor<4xf32>
 }
 
+
+// -----
+
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+  // Fold an insert into a splat.
+  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+  // CHECK-LITERAL:
+  // CHECK-NEXT: return %[[C4]]
+  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4_i32 = arith.constant 4 : i32
+  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+  return %inserted : tensor<2x2xi32>
+}
+
 // -----
 
 // CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
new file mode 100644
index 0000000000000..c301f494a7c87
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s
+
+// CHECK-LABEL: @extract_from_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
+func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
+  %extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
+  %extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
+  func.return %extracted, %extracted_0 : i8, i8
+}
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
+// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
+// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8
+
+// -----
+
+// CHECK-LABEL: @extract_from_static_shape
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
+  %1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index e435130c2a417..0e191c32f009e 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,6 +77,11 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
       llvm::cl::init(false)};
 
+  Option<bool> testFoldExtractFromCollapseShape{
+      *this, "test-fold-extract-from-collapse-shape",
+      llvm::cl::desc("Test folding of extract from collapse_shape"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
   (void)applyPatternsGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateFoldCollapseExtractPatterns(patterns);
+  (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
       return signalPassFailure();
   }
+  if (testFoldExtractFromCollapseShape)
+    applyFoldExtractFromCollapseShapePatterns(rootOp);
   if (testTrackingListener)
     if (failed(testTrackingListenerReplacements(rootOp)))
       return signalPassFailure();

@j2kun j2kun self-requested a review June 2, 2025 19:31
Contributor

@j2kun j2kun left a comment

LGTM!

Will wait for tests to pass and then merge

@j2kun j2kun merged commit 34d8275 into llvm:main Jun 3, 2025
14 checks passed
@joker-eph
Collaborator

I have some concerns about folders in the tensor dialect, in particular:

tensor.insert folder: insert into a constant is replaced with a new constant

This kind of folder can quickly lead to a size explosion because tensor constants can be quite large.
I believe we have shied away from such folders so far in the absence of a good cost model, and relied on heuristics in the canonicalizer to restrict them to "small" tensors.
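
As a rough illustration of one way to bound this, the fold could refuse to fire on large destination constants. A hedged sketch (not code from this PR or its follow-up; the 256-element cap is an arbitrary placeholder):

```cpp
#include <cstdint>
#include "mlir/IR/BuiltinAttributes.h"

// Hypothetical size heuristic: only allow the constant-insert fold when the
// destination constant holds few enough elements. The limit is a placeholder.
static bool isSmallEnoughToFold(mlir::ElementsAttr destAttr) {
  constexpr int64_t kMaxFoldedElements = 256;
  return destAttr.getNumElements() <= kMaxFoldedElements;
}
```

The pattern would then return failure() whenever this check fails, leaving large constants untouched.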

return failure();

// Pattern requires constant indices
SmallVector<uint64_t, 8> indices;
Collaborator

    SmallVector<uint64_t> indices;

Nit: don't use a specific size for SmallVector without a good reason to.

// Update the attribute at the inserted index.
auto sourceValues = sourceAttr.getValues<Attribute>();
auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
std::vector<Attribute> updatedValues;
Collaborator

That is an unfortunate slow path for what will ultimately be converted to a vector of int or float.

Contributor Author

Are you suggesting copying the source values instead of looping, or switching on the int/float type and then converting back to attributes?

Collaborator

I'm saying we should indeed be able to avoid using individual per-element attributes in the common case.
We may be lacking (possibly templated?) helpers to do this conveniently.

Contributor

Yeah, in retrospect it makes sense to type-switch on the attribute; the options are DenseIntElementsAttr and DenseFPElementsAttr.
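
A hedged sketch of what that fast path could look like (hypothetical helper, not code from this PR; it assumes the scalar attribute's type already matches the element type):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper: rebuild a dense constant with one element replaced,
// staying on APInt/APFloat storage instead of materializing an Attribute
// object per element.
static DenseElementsAttr replaceElement(DenseElementsAttr src,
                                        uint64_t flatIndex, Attribute scalar) {
  if (auto ints = dyn_cast<DenseIntElementsAttr>(src)) {
    auto values = llvm::to_vector(ints.getValues<APInt>());
    values[flatIndex] = cast<IntegerAttr>(scalar).getValue();
    return DenseElementsAttr::get(ints.getType(), values);
  }
  if (auto fps = dyn_cast<DenseFPElementsAttr>(src)) {
    auto values = llvm::to_vector(fps.getValues<APFloat>());
    values[flatIndex] = cast<FloatAttr>(scalar).getValue();
    return DenseElementsAttr::get(fps.getType(), values);
  }
  // Fall back to the generic per-Attribute path for other element kinds.
  return {};
}
```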

j2kun pushed a commit that referenced this pull request Jun 5, 2025
…ization (#142671)

Follow ups from #142458
In particular concerns that indiscriminately folding tensor constants
can lead to bloating the IR as these can be arbitrarily large.

Signed-off-by: Asra Ali <asraa@google.com>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jun 5, 2025
…f canonicalization (#142671)

Follow ups from llvm/llvm-project#142458
In particular concerns that indiscriminately folding tensor constants
can lead to bloating the IR as these can be arbitrarily large.

Signed-off-by: Asra Ali <asraa@google.com>