[MLIR][Linalg] pack, unpack to take memref inputs #129036
✅ With the latest revision this PR passed the C/C++ code formatter.
@llvm/pr-subscribers-mlir-memref
Author: Hyunsung Lee (ita9naiwa)
Changes
#129004
Patch is 21.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129036.diff
11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
- RankedTensorType getSourceType() {
- return ::llvm::cast<RankedTensorType>(getSource().getType()); };
- RankedTensorType getDestType() {
- return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+ ShapedType getSourceType() {
+ return ::llvm::cast<ShapedType>(getSource().getType()); };
+ ShapedType getDestType() {
+ return ::llvm::cast<ShapedType>(getDest().getType()); };
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
@@ -152,14 +152,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
- static RankedTensorType inferPackedType(RankedTensorType sourceType,
+ static RankedTensorType inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
@@ -229,6 +229,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
/// 2. pads the other ones, and
/// 3. doesn't shuffle the dimensions
bool isLikePad();
+
}];
let hasCanonicalizeMethod = 1;
@@ -279,13 +280,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
: tensor<8x16x8x32xf32> -> tensor<128x256xf32>
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
#define LINALG_IR_RELAYOUTOPINTERFACE
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/OpBase.td"
def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4c8a214049ea9..8bcc1882b454d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
static MemRefType computeCollapsedType(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type,
+ SmallVector<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
/// Removes the op and replaces the constant with a new constant of the result
/// shape. When an optional cst attribute is passed, it is reshaped only if the
/// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
std::optional<Attribute> cst = std::nullopt);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
}
- RankedTensorType srcPadType = srcPadOp.getSourceType();
+ ShapedType srcPadType = srcPadOp.getSourceType();
SmallVector<OpFoldResult, 4> newSizes;
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
if (srcPadType.isDynamicDim(i)) {
@@ -4433,9 +4433,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("invalid zero tile factor");
// Verify inner_dims_pos and outer_dims_perm.
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- RankedTensorType originalResultType = packOp.getDestType();
+ ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
- packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ packOp.getResult().setType(cast<ShapedType>(dest.getType()));
});
// Insert a cast if needed
if (needUpdateDestType) {
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
- RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
@@ -5002,9 +5001,12 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
}
bool PackOp::isLikePad() {
- auto packedTensorType =
- llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
- return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
@@ -5274,7 +5276,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
}
bool UnPackOp::isLikeUnPad() {
- RankedTensorType packedTensorType = getSourceType();
+ ShapedType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
- RankedTensorType sourceType = packOp.getSourceType();
+ ShapedType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
return failure();
}
- RankedTensorType destType = packOp.getDestType();
+ ShapedType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
"expects outer_dims_perm is empty or an identity permutation");
}
- RankedTensorType sourceType = unpackOp.getSourceType();
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType sourceType = unpackOp.getSourceType();
+ ShapedType destType = unpackOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
return failure();
}
- RankedTensorType sourceType = unpackOp.getSourceType();
+ ShapedType sourceType = unpackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..98dab332b2f40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -359,7 +360,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- RankedTensorType packedTensorType = unPackOp.getSourceType();
+ ShapedType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +397,22 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose packedShape to stripMinedShape.
- RankedTensorType stripMinedTensorType =
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedTensorType, packingMetadata.reassociations);
+ ShapedType stripMinedType;
+ if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+ stripMinedType =
+ RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+ } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+ stripMinedType =
+ MemRefType::get(stripMinedShape, memrefType.getElementType());
+ }
+ ShapedType collapsedType;
+ if (stripMinedType.isa<TensorType>()) {
+ collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
+ } else if (stripMinedType.isa<MemRefType>()) {
+ collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+ cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
+ }
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
@@ -407,7 +420,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, dims, stripMinedTensorType.getElementType());
+ loc, dims, stripMinedType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ ShapedType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto shape = type.getShape();
+ SmallVector<int64_t, 4> newShape;
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ auto band = shape.slice(currentDim, dim);
+ int64_t size = 1;
+ if (llvm::is_contained(band, ShapedType::kDynamic))
+ size = ShapedType::kDynamic;
+ else
+ for (unsigned d = 0; d < dim; ++d)
+ size *= shape[currentDim + d];
+ newShape.push_back(size);
+ currentDim += dim;
+ }
+ return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+ MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+ return inferCollapsedType(
+ type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ type.getContext(), reassociation)));
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- ...
[truncated]
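For readers following the new `memref::CollapseShapeOp::inferCollapsedType` overload in the patch, here is a minimal standalone sketch of the shape arithmetic it performs: each reassociation group folds a band of consecutive source dimensions into one result dimension, and a dynamic extent anywhere in the band makes the whole result dimension dynamic. This is not MLIR code. `collapseShape` and `kDynamicDim` are hypothetical names introduced for illustration, and the sentinel value is an assumption standing in for `ShapedType::kDynamic`. The sketch models shapes only and says nothing about element type, layout, or memory space.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic (assumed sentinel value).
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Collapse `shape` according to `groups`. As in the patch, only the size of
// each group is used: groups are assumed to cover consecutive dimensions in
// order, so the indices inside a group are not inspected.
std::vector<int64_t>
collapseShape(const std::vector<int64_t> &shape,
              const std::vector<std::vector<int64_t>> &groups) {
  std::vector<int64_t> collapsed;
  collapsed.reserve(groups.size());
  std::size_t currentDim = 0;
  for (const auto &group : groups) {
    assert(currentDim + group.size() <= shape.size() && "group exceeds rank");
    int64_t size = 1;
    for (std::size_t d = 0; d < group.size(); ++d) {
      int64_t extent = shape[currentDim + d];
      if (extent == kDynamicDim) {
        // One dynamic extent makes the collapsed dimension dynamic.
        size = kDynamicDim;
        break;
      }
      size *= extent;
    }
    collapsed.push_back(size);
    currentDim += group.size();
  }
  return collapsed;
}
```

With the shapes from the `ReshapeOpsUtils.h` comment touched by this patch, `collapseShape({1, 1, 1, 10}, {{0, 1, 2}, {3}})` yields `{1, 10}`, matching `tensor<1x1x1x10xf32> into tensor<1x10xf32>`.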
@llvm/pr-subscribers-mlir-linalg
Author: Hyunsung Lee (ita9naiwa)
Changes
#129004
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ ShapedType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto shape = type.getShape();
+ SmallVector<int64_t, 4> newShape;
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ auto band = shape.slice(currentDim, dim);
+ int64_t size = 1;
+ if (llvm::is_contained(band, ShapedType::kDynamic))
+ size = ShapedType::kDynamic;
+ else
+ for (unsigned d = 0; d < dim; ++d)
+ size *= shape[currentDim + d];
+ newShape.push_back(size);
+ currentDim += dim;
+ }
+ return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+ MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+ return inferCollapsedType(
+ type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ type.getContext(), reassociation)));
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- ...
[truncated]
|
I expect most of the existing |
Thanks, here is another round of review comments. I think we need to make some changes in the verifier as well.
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 /// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 /// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
nit: this is not a relevant change.
This happens every time I run `clang-format`. It is not relevant; I'll remove this change in the final commit before merge. Is that fine?
@@ -4951,7 +4993,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 rewriter.modifyOpInPlace(packOp, [&] {
 packOp.getSourceMutable().assign(source);
 packOp.getDestMutable().assign(dest);
-packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+packOp.getResult().setType(cast<ShapedType>(dest.getType()));
I believe that this snippet is wrong once we add memref semantics to the pack ops, because:
- There are no results in the memref version, so updating the result type is incorrect.
- The comment above states that it inserts a `tensor.cast` op, while we need a `memref.cast` op for the memref version.
Thus, I suggest applying the canonicalization pattern only when the pack op has tensor semantics. Otherwise, the code looks wrong to me.
Now both the tensor version and the memref version work well.
I think it's a good idea to put these into test cases; what do you think?
module {
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
%unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : memref<2x3xf32> -> memref<2x3xf32>
%packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : memref<2x3xf32> -> memref<2x3xf32>
return %packed : memref<2x3xf32>
}
}
is canonicalized into
module {
func.func @fold_pack_unpack_memref(%arg0: memref<2x3xf32>) -> memref<2x3xf32> {
%unpack = linalg.unpack %arg0 inner_dims_pos = [] inner_tiles = [] into %arg0 : memref<2x3xf32> -> memref<2x3xf32>
return %arg0 : memref<2x3xf32>
}
}
module {
func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
%unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : tensor<2x3xf32> -> tensor<2x3xf32>
%packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
into %x : tensor<2x3xf32> -> tensor<2x3xf32>
return %packed : tensor<2x3xf32>
}
}
reduces into
base ❯ mlir-opt --canonicalize --cse cano-tensor.mlir
module {
func.func @fold_pack_unpack_tensor(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
return %arg0 : tensor<2x3xf32>
}
}
SGTM, please add the tests to https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Linalg/canonicalize.mlir
EDIT: You should have `inner_dims_pos` in the `fold_pack_unpack_memref` test. (We can have canonicalization patterns to fold them away if the configuration is empty and the types statically match.)
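The folding condition mentioned in the parenthesis (empty configuration, statically matching types) can be sketched as a standalone predicate. This is only an illustration in plain C++, not MLIR code: `RelayoutConfig`, `isFoldableNoOp`, and `kDynamicDim` are hypothetical names introduced here, and shapes are modeled as plain integer vectors instead of `ShapedType`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for a dynamic dimension (stand-in for the MLIR one).
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Hypothetical container for the pack/unpack attributes discussed above.
struct RelayoutConfig {
  std::vector<int64_t> innerDimsPos;
  std::vector<int64_t> innerTiles;
  std::vector<int64_t> outerDimsPerm;
};

// Returns true if no dimension of `shape` is dynamic.
bool isStaticShape(const std::vector<int64_t> &shape) {
  for (int64_t dim : shape)
    if (dim == kDynamicDim)
      return false;
  return true;
}

// A pack/unpack is treated as a no-op here only if nothing is tiled or
// permuted and the source and destination shapes are static and identical.
bool isFoldableNoOp(const RelayoutConfig &cfg,
                    const std::vector<int64_t> &sourceShape,
                    const std::vector<int64_t> &destShape) {
  // One tile size per tiled dimension.
  assert(cfg.innerDimsPos.size() == cfg.innerTiles.size());
  bool emptyConfig = cfg.innerDimsPos.empty() && cfg.outerDimsPerm.empty();
  return emptyConfig && isStaticShape(sourceShape) && sourceShape == destShape;
}
```

Under these assumptions, the `memref<2x3xf32>` example above would qualify, while any op with a non-empty `inner_dims_pos` or a dynamic dimension would not.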
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
 // Method to get the `RankedTensorType` of the result based on the inner
 // tiles, position of the inner tiles (innerDimsPos) and interchange vector
 // of outer loops (outerDimsPerm).
-static RankedTensorType inferPackedType(RankedTensorType sourceType,
+static RankedTensorType inferPackedType(ShapedType sourceType,
This looks more like an `inferPackedTensorType` to me now, because it always returns a `RankedTensorType`. In the LLVM codebase, it is used in op verification for shapes and in the data-layout propagation pass. IMO, we can leave it as it is and update the verifier.
We can add a new `inferShape` to the pack op and use it in the verifier.
llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Lines 4471 to 4480 in 4c4e4e4
// Verify result shape is greater than the minimum expected
// by the pack operation, and that the output shape
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
    unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
  return op->emitError("the shape of output is not large enough to hold the "
                       "packed data. Expected at least ")
         << expectedPackedType << ", got " << packedType;
}
On the other hand, please revisit the verifier. I'd expect some changes there. E.g., we do not support unranked tensors/memrefs in Linalg ops, at least not for these pack/unpack ops. Without your change, it emits an error like `custom op 'linalg.pack' invalid kind of type specified`. With your change, it crashes. This is because we switched to `AnyShaped`, and there is no longer a "hasRank" trait in the op definition.
To repro: `mlir-opt repro.mlir`
func.func @pack_tensor(%source: tensor<*xf32>, %dest: tensor<*xf32>) -> tensor<*xf32>{
  %0 = linalg.pack %source outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
    into %dest : tensor<*xf32> -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
IMO, we should check that the types are ranked, and the shapes are compatible.
// Verify that the source and destination are ranked types.
if (!packOrUnPack.getSourceType().hasRank() ||
!packOrUnPack.getDestType().hasRank()) {
return op->emitError("expected both source and destination to be shaped types");
}
I think we can add this check.
@@ -706,3 +706,21 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
If the comment was copied from elsewhere, it may be obsolete. It was originally added together with the lowering of linalg to LLVM. I think the modern convention is that roundtrip/ops tests verify printers and parsers, and lowering tests belong in other lit test files. Thus we don't need this comment.
llvm-project/mlir/test/Dialect/Linalg/roundtrip.mlir
Lines 6 to 25 in b3f01a6
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// DISABLED: mlir-opt %s -o=/dev/null 2>&1
func.func @views(%arg0: index) {
  %c0 = arith.constant 0 : index
  %0 = arith.muli %arg0, %arg0 : index
  %1 = memref.alloc (%0) : memref<?xi8>
  %3 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
  %4 = memref.view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
  memref.dealloc %1 : memref<?xi8>
  return
}
// CHECK-LABEL: func @views
// CHECK: arith.muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: memref.alloc(%{{.*}}) : memref<?xi8>
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
// CHECK-SAME: memref<?xi8> to memref<?x?xf32>
// CHECK-NEXT: memref.view %{{.*}}[%{{.*}}][%{{.*}}] :
// CHECK-SAME: memref<?xi8> to memref<?x?xvector<4x4xf32>>
// CHECK-NEXT: memref.dealloc %{{.*}} : memref<?xi8>
btw, we don't need any return in the memref tests.
Hi @hanhanW, I have addressed most of your review comments. The canonicalization pattern now works for both tensor and memref. I wanted to know how to track down all the patterns related to tensor pack/unpack ops.
I tracked them down with grep. I bailed out of the affected transformations and rewrite patterns using, e.g.,
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
if (!packOp.hasPureTensorSemantics()) {
return failure();
}
Looks better to me, thanks! RE TODO comments: I only expect them in PackAndUnpackPatterns, Vectorization, and canonicalization. Other transforms might not be able to handle the memref case at the moment. I suggest removing the TODOs from the other files; people can start their own projects later if they need it.
// Insert tensor.cast ops if static shape inference is available..
// Insert either tensor.cast or memref.cast ops
// if static shape inference is available..
bool hasTensorSemantics = packOp.hasPureTensorSemantics();
nit: it is only used in the below closure, let's move it into the if-body. Also, we are missing tests if we add such support in the PR. E.g.,
llvm-project/mlir/test/Dialect/Linalg/canonicalize.mlir
Lines 1343 to 1380 in e55164a
// -----
func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = linalg.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
  return %pack : tensor<10x20x30x40x16xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK: %[[PACK:.+]] = linalg.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
// CHECK: return %[[PACK]]
// -----
func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = linalg.pack %src
    padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
  return %pack : tensor<?x?x?x?x16xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK: %[[PACK:.+]] = linalg.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
// CHECK: return %[[CAST_PACK]]
if (hasTensorSemantics)
  dest =
      rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
There are three types in the pack ops on tensors: (1) source type, (2) dest type, and (3) result type.
In the shape inference, we need casting for (1) and (2), so here you also need to take memref into account. (A new test will capture the failure.) For (3), which is updated in the modifyOpInPlace{...} callback, we update the result type if and only if the op is on tensors.
Case (3) only happens on tensors because the memref variant only has types (1) and (2).
Thank you for addressing most of the comments. I think we are close to landing the PR. I left some remarks about the TODOs and code comments, and I think we are missing some tests for the canonicalization patterns. Please add tests to reflect the changes.
@@ -78,7 +78,7 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
omp::FlushOp, omp::MapBoundsOp,
omp::ThreadprivateOp>::value) {
if (isa<MemRefType>(originalOperand.getType())) {
// TODO: Support memref type in variable operands
// TODO: Support Memref PackOp. Temporarily return failure.
Please remove the change. I don't think you intended to update this. :)
|
||
|
nit: delete one blank line
if (hasTensorSemantics) {
  auto castOp =
      rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
} else {
  auto castOp =
      rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
}
Yes, I think so.
/// a way that ensures that they agree on which dimensions are dynamic.
/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
/// of the packed type. Having a shared helper helps implement these two methods
/// in a way that ensures that they agree on which dimensions are dynamic.
We should replace the function with the new inferPackedShape method. I don't see the value of having an indirect call. I.e.,
static SmallVector<int64_t> getPackOpResultTypeShape(
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
can become
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
Please remember to add a similar comment to the td file, e.g., that it helps ensure all the shape inference methods agree on which dimensions are dynamic.
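To make the shape rule behind the proposed inferPackedShape concrete, here is a minimal standalone sketch. It is not the MLIR implementation: it uses std::vector instead of ArrayRef/SmallVector, the kDynamic sentinel is a stand-in for ShapedType::kDynamic, and the free-function form is an assumption made only so the example compiles on its own. Tiled outer dimensions are divided (rounding up) by their tile size, the outer dimensions are permuted, and the inner tile sizes are appended.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumption: any sentinel works here).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical standalone version of the packed-shape computation.
std::vector<int64_t> inferPackedShape(const std::vector<int64_t> &inputShape,
                                      const std::vector<int64_t> &innerTileSizes,
                                      const std::vector<int64_t> &innerDimsPos,
                                      const std::vector<int64_t> &outerDimsPerm) {
  assert(innerTileSizes.size() == innerDimsPos.size());
  std::vector<int64_t> outer = inputShape;
  for (std::size_t i = 0; i < innerDimsPos.size(); ++i) {
    int64_t &dim = outer[innerDimsPos[i]];
    if (dim == kDynamic)
      continue; // A dynamic source dimension stays dynamic.
    if (innerTileSizes[i] == kDynamic) {
      dim = kDynamic; // A dynamic tile size makes the outer dimension dynamic.
      continue;
    }
    // Ceiling division: a partial tile still occupies one outer slot.
    dim = (dim + innerTileSizes[i] - 1) / innerTileSizes[i];
  }
  // Permute the outer dimensions: result[i] = outer[perm[i]].
  std::vector<int64_t> result;
  if (outerDimsPerm.empty()) {
    result = outer;
  } else {
    for (int64_t p : outerDimsPerm)
      result.push_back(outer[p]);
  }
  // The inner tile sizes become the trailing dimensions.
  result.insert(result.end(), innerTileSizes.begin(), innerTileSizes.end());
  return result;
}
```

With a single function of this kind, every caller that needs the packed shape goes through the same dynamic-dimension rules, which is the consistency property the td comment is meant to document.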
if (!packOp.hasPureTensorSemantics())
  return failure();
I think it is a redundant check, as all the precondition checks happen in the matchAndRewrite methods. Can you remove it?
Please remove all the TODOs from this file. I think they are not trivial to resolve, because we don't have a memref.pad op. No one will clear the TODOs in this case.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <iostream>
IWYU, please delete the include.
if (!op.hasPureTensorSemantics())
  return failure();
Let's add a TODO for consistency. It is a reasonable folder to me and we should support it (in follow-ups).
This is unpack, not pack.
@@ -4951,7 +4993,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
packOp.getResult().setType(cast<ShapedType>(dest.getType()));
SGTM, please add the tests to https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Linalg/canonicalize.mlir
EDIT: You should have inner_dims_pos in the fold_pack_unpack_memref test. (We can have canonicalization patterns to fold them away if the configuration is empty and the types statically match.)
I accidentally clang-formatted invalid.mlir; I will fix it very soon.
Signed-off-by: Hyunsung Lee <ita9naiwa@gmail.com>
e660c40
to
17ad838
The canonicalization pattern is off, which leads to an invalid IR. Other parts look good, just some nits about comments.
if (!unpackOp.hasPureTensorSemantics()) {
  return failure();
}
bump
This is not related to the PR. Can you revert the change?
// TODO: Support Memref PackOp. Temporarily return just Op Source.
if (!packOp.hasPureTensorSemantics())
  return input;
I don't think this check is necessary, because it is a local function and we already checked it in DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite. I think the transformations are built around the tensor variant. I would not add a TODO for memref unless we have a plan for it.
@@ -265,6 +268,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
highs[pos] = affine::makeComposedFoldedAffineApply(
rewriter, loc, map, {outerSize, origSize, innerSize});
}
// TODO: Need memref.pad operation to support memref operands
What is memref.pad? I think this will be a question from other contributors in the future. I'd remove the comment to avoid confusion in the first place.
if (!op.hasPureTensorSemantics())
  return failure();
This is unpack, not pack.
Co-authored-by: Han-Chung Wang <hanhan0912@gmail.com> Signed-off-by: Hyunsung Lee <ita9naiwa@gmail.com>
7b86f9b
to
2aca3fd
Sorry, I made a mistake in my earlier review about the folding. After taking a look at the real IR tests, I think there are two issues.
- The memref version should not return a memref. By definition, the op packs data from the source buffer and stores the result to the destination buffer. Like other linalg operations, it should not return any value when it has buffer semantics. In other linalg ops, the Variadic<AnyRankedTensor> adds the check. This was pointed out in a comment from @adam-smnk, and I think it is not resolved.
Invalid case:
%packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_packed : memref<40x80xf32> -> memref<10x20x4x4xf32>
It should be like this valid case:
linalg.pack %unpacked
inner_dims_pos = [0, 1] inner_tiles = [4, 4]
into %buf_packed
: memref<40x80xf32> -> memref<10x20x4x4xf32>
- The second issue is about the folding on memrefs. TL;DR: we should disable it, because the current logic is not correct. To implement the folding properly, we would need to look at the source memref and all other uses of it. That could be expensive, and the logic is more complicated on memrefs. The example below is not foldable if other ops use %buf_unpacked in the middle. Things become far more complicated when control flow is involved. I think this kind of folding should be implemented as a pass. (Please remember to remove the test cases in canonicalize.mlir if the folding is disabled. I also did not see why you added some folding tests for tensors, while the PR is scoped to adding memref support. Maybe we can drop the tests for tensors as well?)
linalg.unpack %t
inner_dims_pos = [0, 1] inner_tiles = [4, 4]
into %buf_unpacked
: memref<10x20x4x4xf32> -> memref<40x80xf32>
linalg.pack %unpacked
inner_dims_pos = [0, 1] inner_tiles = [4, 4]
into %buf_packed
: memref<40x80xf32> -> memref<10x20x4x4xf32>
(The remaining issue is about reverting unrelated code changes; I think you mentioned that you will remove them before landing the PR.)
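The hazard with folding unpack followed by pack on buffers can be illustrated with a toy model. This is not MLIR code: the buffers are plain std::array values, the 4x4 shape with 2x2 tiles is chosen arbitrarily, and pack, unpack, and foldWouldBeCorrect are hypothetical helpers. The sketch checks whether the packed destination equals the original packed source, which is what a fold would assume.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr int N = 4; // The unpacked buffer is N x N.
constexpr int T = 2; // Inner tiles are T x T.
using Buffer = std::array<float, N * N>;

// Linear index into the packed layout [N/T][N/T][T][T].
int packedIndex(int row, int col) {
  int outerCols = N / T;
  return (((row / T) * outerCols + (col / T)) * T + row % T) * T + col % T;
}

// Toy model of linalg.pack with buffer semantics: writes into dest.
void pack(const Buffer &src, Buffer &dest) {
  for (int r = 0; r < N; ++r)
    for (int c = 0; c < N; ++c)
      dest[packedIndex(r, c)] = src[r * N + c];
}

// Toy model of linalg.unpack with buffer semantics: writes into dest.
void unpack(const Buffer &src, Buffer &dest) {
  for (int r = 0; r < N; ++r)
    for (int c = 0; c < N; ++c)
      dest[r * N + c] = src[packedIndex(r, c)];
}

// Returns true when replacing the pack result with the unpack source
// would have produced the same contents.
bool foldWouldBeCorrect(bool writeInBetween) {
  Buffer t{}, bufUnpacked{}, bufPacked{};
  for (std::size_t i = 0; i < t.size(); ++i)
    t[i] = static_cast<float>(i);
  unpack(t, bufUnpacked);
  if (writeInBetween)
    bufUnpacked[0] = -1.0f; // Another op touches the intermediate buffer.
  pack(bufUnpacked, bufPacked);
  return bufPacked == t;
}
```

In this model the fold is only valid when nothing writes to the intermediate buffer between the two ops, which is the use analysis the review says a canonicalization pattern should not attempt.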
#129004
ShapedType, not RankedTensorType
MemrefType and TensorType
MemrefType