[MLIR][Linalg] pack, unpack to take memref inputs #129036

Open
wants to merge 35 commits into base: main

Conversation

ita9naiwa
Contributor

@ita9naiwa ita9naiwa commented Feb 27, 2025

#129004

  • Change to use ShapedType instead of RankedTensorType
    • Accordingly, add branches to handle both MemRefType and TensorType (see the sketch below)
    • Add memref::CollapseShapeOp::inferCollapsedType to handle certain MemRefType cases
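
A minimal C++ sketch of the dispatch pattern described in the bullets above, assuming the patch below applies; cloneWithShape is an illustrative helper, not a function from the patch:

```c++
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Branch on the concrete shaped type only where tensor- or memref-specific
// construction is needed; everything else can work on ShapedType directly.
static Type cloneWithShape(ShapedType type, ArrayRef<int64_t> newShape) {
  if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
    return RankedTensorType::get(newShape, tensorType.getElementType());
  if (auto memrefType = llvm::dyn_cast<MemRefType>(type))
    return MemRefType::get(newShape, memrefType.getElementType());
  // Unranked or other shaped types are not handled in this sketch.
  return Type();
}
```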

@ita9naiwa ita9naiwa changed the title from "[draft][mlir][" to "[draft][mlir][linalg] pack, unpack to take memref inputs" on Feb 27, 2025

github-actions bot commented Feb 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Member

llvmbot commented Mar 1, 2025

@llvm/pr-subscribers-mlir-memref

Author: Hyunsung Lee (ita9naiwa)

Changes

#129004

  • Change to use ShapedType instead of RankedTensorType
    • Accordingly, add branches to handle both MemRefType and TensorType
    • Add memref::CollapseShapeOp::inferCollapsedType to handle certain MemRefType cases

Patch is 21.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129036.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td (+12-11)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td (+1)
  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+5)
  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+15-13)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+19-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+31-1)
  • (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+6-6)
  • (modified) mlir/test/Dialect/Linalg/loops.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
   code commonExtraClassDeclaration = [{
     size_t getSourceRank() { return getSourceType().getRank(); };
     size_t getDestRank() { return getDestType().getRank(); };
-    RankedTensorType getSourceType() {
-      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
-    RankedTensorType getDestType() {
-      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+    ShapedType getSourceType() {
+      return ::llvm::cast<ShapedType>(getSource().getType()); };
+    ShapedType getDestType() {
+      return ::llvm::cast<ShapedType>(getDest().getType()); };
 
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
 
@@ -152,14 +152,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Note: Only tiled dimensions can be padded.
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        Optional<AnyType>:$padding_value,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
-    static RankedTensorType inferPackedType(RankedTensorType sourceType,
+    static RankedTensorType inferPackedType(ShapedType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
@@ -229,6 +229,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     /// 2. pads the other ones, and
     /// 3. doesn't shuffle the dimensions
     bool isLikePad();
+
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -279,13 +280,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
         : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`outer_dims_perm` `=` $outer_dims_perm^)?
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
 #define LINALG_IR_RELAYOUTOPINTERFACE
 
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/OpBase.td"
 
 def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4c8a214049ea9..8bcc1882b454d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
 
     static MemRefType computeCollapsedType(
         MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+    static MemRefType
+        inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
+    static MemRefType
+        inferCollapsedType(MemRefType type,
+                           SmallVector<ReassociationIndices> reassociation);
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///          tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
 /// Removes the op and replaces the constant with a new constant of the result
 /// shape. When an optional cst attribute is passed, it is reshaped only if the
 /// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
                                    std::optional<Attribute> cst = std::nullopt);
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
           rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
     }
 
-    RankedTensorType srcPadType = srcPadOp.getSourceType();
+    ShapedType srcPadType = srcPadOp.getSourceType();
     SmallVector<OpFoldResult, 4> newSizes;
     for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
       if (srcPadType.isDynamicDim(i)) {
@@ -4433,9 +4433,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("invalid zero tile factor");
 
   // Verify inner_dims_pos and outer_dims_perm.
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
                                          ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
-    RankedTensorType originalResultType = packOp.getDestType();
+    ShapedType originalResultType = packOp.getDestType();
     bool needUpdateDestType = (destShape != originalResultType.getShape());
     if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.modifyOpInPlace(packOp, [&] {
       packOp.getSourceMutable().assign(source);
       packOp.getDestMutable().assign(dest);
-      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+      packOp.getResult().setType(cast<ShapedType>(dest.getType()));
     });
     // Insert a cast if needed
     if (needUpdateDestType) {
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 }
 
 template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");
@@ -5002,9 +5001,12 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
 }
 
 bool PackOp::isLikePad() {
-  auto packedTensorType =
-      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
-  return isLikePadUnPad(*this, packedTensorType);
+  if (auto packedTensorType =
+          llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+    return isLikePadUnPad(*this, packedTensorType);
+  if (auto packedTensorType =
+          llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+    return isLikePadUnPad(*this, packedTensorType);
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
@@ -5274,7 +5276,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
 }
 
 bool UnPackOp::isLikeUnPad() {
-  RankedTensorType packedTensorType = getSourceType();
+  ShapedType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
     if (packOp.getPaddingValue())
       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
 
-    RankedTensorType sourceType = packOp.getSourceType();
+    ShapedType sourceType = packOp.getSourceType();
     if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
         failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                           packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
       return failure();
     }
 
-    RankedTensorType destType = packOp.getDestType();
+    ShapedType destType = packOp.getDestType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
           "expects outer_dims_perm is empty or an identity permutation");
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType sourceType = unpackOp.getSourceType();
+    ShapedType destType = unpackOp.getDestType();
     if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
       return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
 
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType destType = unpackOp.getDestType();
     if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
         failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                           unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
       return failure();
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
+    ShapedType sourceType = unpackOp.getSourceType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..98dab332b2f40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -359,7 +360,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  ShapedType packedTensorType = unPackOp.getSourceType();
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +397,22 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
   // 3. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
+  ShapedType stripMinedType;
+  if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+    stripMinedType =
+        RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+  } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+    stripMinedType =
+        MemRefType::get(stripMinedShape, memrefType.getElementType());
+  }
+  ShapedType collapsedType;
+  if (stripMinedType.isa<TensorType>()) {
+    collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+        cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
+  } else if (stripMinedType.isa<MemRefType>()) {
+    collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+        cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
+  }
 
   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
   // permutation.
@@ -407,7 +420,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
   applyPermutationToVector(dims, packedToStripMinedShapePerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, dims, stripMinedTensorType.getElementType());
+      loc, dims, stripMinedType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  ShapedType unpackTensorType = unpackOp.getSourceType();
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    }   // Check condition 2
+    } // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+                                    ArrayRef<AffineMap> reassociation) {
+  auto shape = type.getShape();
+  SmallVector<int64_t, 4> newShape;
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    auto band = shape.slice(currentDim, dim);
+    int64_t size = 1;
+    if (llvm::is_contained(band, ShapedType::kDynamic))
+      size = ShapedType::kDynamic;
+    else
+      for (unsigned d = 0; d < dim; ++d)
+        size *= shape[currentDim + d];
+    newShape.push_back(size);
+    currentDim += dim;
+  }
+  return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+    MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+  return inferCollapsedType(
+      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                type.getContext(), reassociation)));
+}
+
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-                            ...
[truncated]
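
For context, a minimal usage sketch (not part of the patch) of the new memref::CollapseShapeOp::inferCollapsedType overload declared above; exampleCollapse is an illustrative helper:

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Infer the collapsed type of a memref<4x8x16xf32> when dims {0, 1} are
// merged and dim {2} is kept; the expected result is memref<32x16xf32>.
static MemRefType exampleCollapse(MLIRContext *ctx) {
  Builder b(ctx);
  auto srcType = MemRefType::get({4, 8, 16}, b.getF32Type());
  SmallVector<ReassociationIndices> reassoc = {{0, 1}, {2}};
  return memref::CollapseShapeOp::inferCollapsedType(srcType, reassoc);
}
```

Note that, as written in the patch, the inferred type keeps only the shape and element type; any source layout or memory space is dropped.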

@llvmbot
Member

llvmbot commented Mar 1, 2025

@llvm/pr-subscribers-mlir-linalg

@llvmbot
Member

llvmbot commented Mar 1, 2025

@llvm/pr-subscribers-mlir

-      stripMinedTensorType, packingMetadata.reassociations);
+  ShapedType stripMinedType;
+  if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+    stripMinedType =
+        RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+  } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+    stripMinedType =
+        MemRefType::get(stripMinedShape, memrefType.getElementType());
+  }
+  ShapedType collapsedType;
+  if (stripMinedType.isa<TensorType>()) {
+    collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+        cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
+  } else if (stripMinedType.isa<MemRefType>()) {
+    collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+        cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
+  }
 
   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
   // permutation.
@@ -407,7 +420,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
   applyPermutationToVector(dims, packedToStripMinedShapePerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, dims, stripMinedTensorType.getElementType());
+      loc, dims, stripMinedType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  ShapedType unpackTensorType = unpackOp.getSourceType();
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    }   // Check condition 2
+    } // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+                                    ArrayRef<AffineMap> reassociation) {
+  auto shape = type.getShape();
+  SmallVector<int64_t, 4> newShape;
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    auto band = shape.slice(currentDim, dim);
+    int64_t size = 1;
+    if (llvm::is_contained(band, ShapedType::kDynamic))
+      size = ShapedType::kDynamic;
+    else
+      for (unsigned d = 0; d < dim; ++d)
+        size *= shape[currentDim + d];
+    newShape.push_back(size);
+    currentDim += dim;
+  }
+  return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+    MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+  return inferCollapsedType(
+      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                type.getContext(), reassociation)));
+}
+
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-                            ...
[truncated]

@ita9naiwa ita9naiwa changed the title [draft][mlir][linalg] pack, unpack to take memref inputs [MLIR][Linalg] pack, unpack to take memref inputs Mar 1, 2025
@rengolin rengolin requested review from adam-smnk and chelini March 1, 2025 05:59
@adam-smnk
Contributor

I expect most of the existing linalg.pack/unpack transforms can't handle memrefs. It'd be great to restrict them to ops with pure tensor semantics.
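For illustration, a minimal sketch of such a restriction (the pattern name is made up; only the hasPureTensorSemantics guard is the point):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Illustrative tensor-only pattern: bail out early for the memref form so the
// existing tensor-based rewrite logic never touches buffer-semantics ops.
struct TensorOnlyPackPattern : public OpRewritePattern<linalg::PackOp> {
  using OpRewritePattern<linalg::PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (!packOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          packOp, "only pack ops with pure tensor semantics are supported");
    // ... the existing tensor-based rewrite would go here ...
    return failure();
  }
};
} // namespace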

Contributor

@hanhanW hanhanW left a comment

Thanks, here is the other round of review comments. I think we need to make some changes in the verifier as well.

@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
Contributor

nit: this is not a relevant change.

Contributor Author

This happens every time I run clang-format; it's not a relevant change. I'll remove it in the final commit before merging. Is that fine?

@@ -4951,7 +4993,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
packOp.getResult().setType(cast<ShapedType>(dest.getType()));
Contributor

I believe this snippet becomes wrong once we add memref semantics to pack ops, because:

  1. There are no results in the memref case, so updating the result type is incorrect.
  2. The comment above states that it inserts a tensor.cast op, while the memref version needs a memref.cast op.

Thus, I suggest kicking in the canonicalization pattern only when the pack op has tensor semantics. Otherwise, the code looks wrong to me.

Contributor Author

Now both the tensor and the memref versions work well.

I think it's a good idea to put these into test cases; what do you think?

module {
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
               into %x : memref<2x3xf32> -> memref<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
             into %x : memref<2x3xf32> -> memref<2x3xf32>
  return %packed : memref<2x3xf32>
}
}

is canonicalized into

module {
  func.func @fold_pack_unpack_memref(%arg0: memref<2x3xf32>) -> memref<2x3xf32> {
    %unpack = linalg.unpack %arg0 inner_dims_pos = [] inner_tiles = [] into %arg0 : memref<2x3xf32> -> memref<2x3xf32>
    return %arg0 : memref<2x3xf32>
  }
}
module {
func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
             into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
             into %x : tensor<2x3xf32> -> tensor<2x3xf32>
  return %packed : tensor<2x3xf32>
}
}

reduces into

mlir-opt --canonicalize --cse cano-tensor.mlir
module {
  func.func @fold_pack_unpack_tensor(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
    return %arg0 : tensor<2x3xf32>
  }
}

Contributor

SGTM, please add the tests to https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Linalg/canonicalize.mlir

EDIT: You should have inner_dims_pos in the fold_pack_unpack_memref test. (We can have canonicalization patterns to fold them away if the configuration is empty and the types statically match.)

@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static RankedTensorType inferPackedType(RankedTensorType sourceType,
static RankedTensorType inferPackedType(ShapedType sourceType,
Contributor

This looks more like an inferPackedTensorType to me now, because it always returns a RankedTensorType. In the LLVM codebase, it is used in op verification for shapes and in the data-layout propagation pass. IMO, we can leave it as it is and update the verifier.

We can add a new inferShape method to the pack op and use it in the verifier.

// Verify result shape is greater than the minimum expected
// by the pack operation, and that the output shape
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
    unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
  return op->emitError("the shape of output is not large enough to hold the "
                       "packed data. Expected at least ")
         << expectedPackedType << ", got " << packedType;
}

On the other hand, please revisit the verifier; I'd expect some changes there. E.g., we do not support unranked tensor/memref in Linalg ops, at least for these pack/unpack ops. Without your change, it emits an error like custom op 'linalg.pack' invalid kind of type specified. With your change, it crashes. This is because we switched to AnyShaped, and there is no longer a "hasRank" constraint in the op definition.

To repro: mlir-opt repro.mlir

func.func @pack_tensor(%source: tensor<*xf32>, %dest: tensor<*xf32>) -> tensor<*xf32> {
  %0 = linalg.pack %source outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
      into %dest : tensor<*xf32> -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

IMO, we should check that the types are ranked, and the shapes are compatible.
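A minimal sketch of the kind of verifier check being suggested (helper name assumed; this is not the exact code of the PR):

// Both types must be ranked, and every static dimension of the packed type
// must be large enough to hold the inferred packed shape; dynamic dimensions
// are left unconstrained.
static LogicalResult verifyPackedShape(Operation *op, ShapedType packedType,
                                       ArrayRef<int64_t> expectedShape) {
  if (!packedType.hasRank())
    return op->emitError("expected ranked source and destination types");
  if (packedType.getRank() != static_cast<int64_t>(expectedShape.size()))
    return op->emitError("packed rank does not match the inferred rank");
  for (auto [actual, expected] :
       llvm::zip(packedType.getShape(), expectedShape)) {
    if (ShapedType::isDynamic(actual) || ShapedType::isDynamic(expected))
      continue;
    if (actual < expected)
      return op->emitError("the shape of the output is not large enough to "
                           "hold the packed data");
  }
  return success();
}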

Contributor Author

// Verify that the source and destination are ranked types.
if (!packOrUnPack.getSourceType().hasRank() ||
    !packOrUnPack.getDestType().hasRank()) {
  return op->emitError("expected both source and destination to be ranked types");
}

I think we can add this check.

@@ -706,3 +706,21 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
// CHECK-LABEL: func @conv2d_channel_first_q_promote(
// CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>

// -----
// Test that we can lower all the way to LLVM without crashing, don't check results here.
Contributor

If the comment was copied from other places, it may be obsolete. The comment was added when lowering linalg to LLVM. I think the modern convention is that roundtrip/ops tests verify printers and parsers; the lowering should be moved to other lit test files. Thus we don't need such a comment.

// Test that we can lower all the way to LLVM without crashing, don't check results here.
// DISABLED: mlir-opt %s -o=/dev/null 2>&1
func.func @views(%arg0: index) {
  %c0 = arith.constant 0 : index
  %0 = arith.muli %arg0, %arg0 : index
  %1 = memref.alloc (%0) : memref<?xi8>
  %3 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
  %4 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
  memref.dealloc %1 : memref<?xi8>
  return
}
// CHECK-LABEL: func @views
// CHECK: arith.muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: memref.alloc(%{{.*}}) : memref<?xi8>
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
// CHECK-SAME: memref<?xi8> to memref<?x?xf32>
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
// CHECK-SAME: memref<?xi8> to memref<?x?xvector<4x4xf32>>
// CHECK-NEXT: memref.dealloc %{{.*}} : memref<?xi8>

btw, we don't need any return in the memref tests.

@ita9naiwa
Contributor Author

Hi @hanhanW, I've mostly addressed your review comments. The canonicalization pattern now works for both tensor and memref.

I'd like to know how to track down all the patterns related to tensor pack/unpack ops.

@ita9naiwa ita9naiwa requested a review from hanhanW March 29, 2025 23:52
@ita9naiwa
Contributor Author

I'd like to know how to track down all the patterns related to tensor pack/unpack ops.

I tracked them down by grepping for linalg::PackOp and linalg::UnPackOp.

I bailed out of the transformations and rewrite patterns using, e.g.,

// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
if (!packOp.hasPureTensorSemantics()) {
  return failure();
}

Contributor

@hanhanW hanhanW left a comment

Looks better to me, thanks! Re the TODO comments: I only expect them in PackAndUnpackPatterns, Vectorization, and canonicalization. Other transforms might not be able to handle the memref case at the moment. I suggest removing the TODOs from the other files; people can start those projects later if they need them.

// Insert tensor.cast ops if static shape inference is available..
// Insert either tensor.cast or memref.cast ops
// if static shape inference is available..
bool hasTensorSemantics = packOp.hasPureTensorSemantics();
Contributor

nit: it is only used in the closure below; let's move it into the if-body. Also, we are missing tests if we add such support in this PR. E.g.,

// -----
func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = linalg.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
  return %pack : tensor<10x20x30x40x16xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK: %[[PACK:.+]] = linalg.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
// CHECK: return %[[PACK]]
// -----
func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = linalg.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
  return %pack : tensor<?x?x?x?x16xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK: %[[PACK:.+]] = linalg.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
// CHECK: return %[[CAST_PACK]]

Comment on lines 5027 to 5029
if (hasTensorSemantics)
  dest =
      rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
Contributor

There are three types in the pack op on tensors: (1) the source type, (2) the dest type, and (3) the result type.

In the shape inference, we need casting for (1) and (2), so here you also need to take memref into account (a new test will capture the failure). For (3), which is updated in the modifyOpInPlace{...}, we update the result type if and only if the op is on tensors.

(3) only exists on tensors, because the memref variant only has types (1) and (2).
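For illustration, a small helper along these lines would handle the (1)/(2) casts for both forms (helper name assumed; not the exact PR code):

// Insert a tensor.cast for the tensor form and a memref.cast for the memref
// form; used when refining the source/dest types during shape inference.
static Value castToShapedType(OpBuilder &builder, Location loc, Value value,
                              ShapedType newType) {
  if (isa<RankedTensorType>(newType))
    return builder.create<tensor::CastOp>(loc, newType, value);
  return builder.create<memref::CastOp>(loc, newType, value);
}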

@ita9naiwa ita9naiwa requested a review from hanhanW April 2, 2025 22:13
Contributor

@hanhanW hanhanW left a comment

Thank you for addressing most of the comments. I think we are close to landing the PR. I left some comments about TODOs and comments, and I think we are missing some tests for the canonicalization patterns. Please add tests to reflect the changes.

@@ -78,7 +78,7 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
omp::FlushOp, omp::MapBoundsOp,
omp::ThreadprivateOp>::value) {
if (isa<MemRefType>(originalOperand.getType())) {
// TODO: Support memref type in variable operands
// TODO: Support Memref PackOp. Temporarily return failure.
Contributor

Please remove the change. I think you did not intend to update this. :)

Comment on lines 1678 to 1679


Contributor

nit: delete one blank line

Comment on lines 5039 to 5047
if (hasTensorSemantics) {
  auto castOp =
      rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
} else {
  auto castOp =
      rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
}
Contributor

Yes, I think so.

/// a way that ensures that they agree on which dimensions are dynamic.
/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
/// of the packed type. Having a shared helper helps implement these two methods
/// in a way that ensures that they agree on which dimensions are dynamic.
Contributor

We should replace the function with the new inferPackedShape method. I don't see the value of having an indirect call. I.e.,

static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {

can become

SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
                                              ArrayRef<int64_t> innerTileSizes,
                                              ArrayRef<int64_t> innerDimsPos,
                                              ArrayRef<int64_t> outerDimsPerm) {

Please remember to add a similar comment to the .td file, e.g., that it helps ensure all the shape inference methods agree on which dimensions are dynamic.
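For example, the tensor-specific helper could then be a thin wrapper over the shared shape inference (a sketch following the naming suggested above; not actual code from the PR or the LLVM tree):

// Both the verifier and this helper go through inferPackedShape, so they
// agree on which dimensions end up dynamic.
RankedTensorType PackOp::inferPackedTensorType(
    RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = inferPackedShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}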

Comment on lines 379 to 381
if (!packOp.hasPureTensorSemantics())
  return failure();

Contributor

I think it is a redundant check, as all the precondition checks happen in matchAndRewrite methods. Can you remove it?

Contributor

Please remove all the TODOs from this file. I think they are not trivial, because there is no memref.pad op; no one will clear such a TODO.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <iostream>
Contributor

IWYU, please delete the include.

Comment on lines +5426 to +5427
if (!op.hasPureTensorSemantics())
  return failure();
Contributor

Let's add a TODO for consistency. It is a reasonable folder to me and we should support it (in follow-ups).

Contributor

This is unpack, not pack.

@ita9naiwa
Contributor Author

I accidentally clang-formatted invalid.mlir; I will fix it very soon.

@ita9naiwa ita9naiwa requested a review from hanhanW April 14, 2025 10:25
Signed-off-by: Hyunsung Lee <ita9naiwa@gmail.com>
@ita9naiwa ita9naiwa force-pushed the ita9naiwa/pack-memref branch from e660c40 to 17ad838 Compare April 17, 2025 07:53
Contributor

@hanhanW hanhanW left a comment

The canonicalization pattern is off, which leads to invalid IR. The other parts look good; just some nits about comments.

Comment on lines 1902 to 1904
if (!unpackOp.hasPureTensorSemantics()) {
  return failure();
}
Contributor

bump

Contributor

This is not related to the PR. Can you revert the change?

Comment on lines 1033 to 1036
// TODO: Support Memref PackOp. Temporarily return just Op Source.
if (!packOp.hasPureTensorSemantics())
  return input;

Contributor

This check is not necessary to me, because it is a local function and we already check this in DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite. I think the transformations are built around the tensor variant; I wouldn't put a TODO for memref unless we have a plan for it.

@@ -265,6 +268,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
highs[pos] = affine::makeComposedFoldedAffineApply(
rewriter, loc, map, {outerSize, origSize, innerSize});
}
// TODO: Need memref.pad operation to support memref operands
Contributor

What is memref.pad? I think this will be a question from other contributors in the future. I'd remove the comment to avoid confusion in the first place.


Co-authored-by: Han-Chung Wang <hanhan0912@gmail.com>
Signed-off-by: Hyunsung Lee <ita9naiwa@gmail.com>
@ita9naiwa ita9naiwa force-pushed the ita9naiwa/pack-memref branch from 7b86f9b to 2aca3fd Compare April 20, 2025 04:55
Contributor

@hanhanW hanhanW left a comment

Sorry that I made a mistake in the review about the folding. After taking a look at the real IR tests, I think there are two issues.

  1. The memref version should not return a memref. By definition, the op packs from the source buffer and stores the result in the destination buffer. Like other linalg operations, it should not return any value when it has buffer semantics. In other linalg ops, the Variadic<AnyRankedTensor> result adds that check. This was pointed out in a comment from @adam-smnk, and I don't think it has been resolved.

Invalid case:

%packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_packed : memref<40x80xf32> -> memref<10x20x4x4xf32>

It should be like this valid case:

linalg.pack %unpacked
  inner_dims_pos = [0, 1] inner_tiles = [4, 4]
  into %buf_packed
  : memref<40x80xf32> -> memref<10x20x4x4xf32>
  2. The second issue is about the folding on memrefs. TL;DR: we should disable it, because the current logic is not correct. To implement the folding, we would need to look at the source memref and all other uses of it; that can be expensive, and the logic is more complicated on memrefs. The example below is not foldable if there are other ops using %buf_unpacked in the middle, and things get much more complicated when control flow is involved. I think this kind of folding should be implemented as a pass. (Please remember to remove the test cases in canonicalization.mlir if they are disabled. I also did not see why you added some folding tests for tensors, while the PR is scoped to adding memref support. Maybe we can drop the tensor tests as well?)
linalg.unpack %t
  inner_dims_pos = [0, 1] inner_tiles = [4, 4]
  into %buf_unpacked
  : memref<10x20x4x4xf32> -> memref<40x80xf32>
linalg.pack %unpacked
  inner_dims_pos = [0, 1] inner_tiles = [4, 4]
  into %buf_packed
  : memref<40x80xf32> -> memref<10x20x4x4xf32>

(The remaining issue is about reverting unrelated code changes; I think you mentioned that you will remove them before landing the PR.)
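For the first issue, one way to enforce it in the verifier would be along these lines (a sketch with an assumed helper name, not code from this PR):

// The tensor form must produce exactly one tensor result; the buffer form,
// like other destination-style linalg ops on memrefs, must produce none.
static LogicalResult verifyResultPresence(linalg::PackOp packOp) {
  if (packOp.hasPureTensorSemantics()) {
    if (packOp->getNumResults() != 1)
      return packOp.emitOpError("expected a single tensor result");
    return success();
  }
  if (packOp->getNumResults() != 0)
    return packOp.emitOpError("expected no results when operating on memrefs");
  return success();
}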
