Skip to content

Commit 2aca3fd

Browse files
ita9naiwa authored and hanhanW committed
Update mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Co-authored-by: Han-Chung Wang <hanhan0912@gmail.com> Signed-off-by: Hyunsung Lee <ita9naiwa@gmail.com>
1 parent 17ad838 commit 2aca3fd

File tree

6 files changed

+16
-42
lines changed

6 files changed

+16
-42
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

+1-12
Original file line numberDiff line numberDiff line change
@@ -190,29 +190,18 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
190190
// Method to get the `RankedTensorType` of the result based on the inner
191191
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
192192
// of outer loops (outerDimsPerm).
193-
/// This method uses inferPackedShape to ensure consistency with other shape
194-
/// inference methods regarding which dimensions are dynamic.
195193
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
196194
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
197195
ArrayRef<int64_t> outerDimsPerm = {});
198196

199197
// Method to get the `MemRefType` of the result based on the inner
200198
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
201199
// of outer loops (outerDimsPerm).
202-
/// This method uses inferPackedShape to ensure consistency with other shape
203-
/// inference methods regarding which dimensions are dynamic.
204200
static MemRefType inferPackedMemRefType(MemRefType sourceType,
205201
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
206202
ArrayRef<int64_t> outerDimsPerm = {});
207203

208-
// Method to get the Shape of the result based on the input shape, inner
209-
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
210-
// of outer loops (outerDimsPerm).
211-
212-
/// Helper for PackOp::{getResultShape, inferPackedTensorType, inferPackedMemRefType}.
213-
/// Returns the shape of the packed type. Having a shared helper helps
214-
/// implement these three methods in a way that ensures
215-
/// that they agree on which dimensions are dynamic.
204+
// Returns the shape of the packed type. It is a shared helper that helps the type inference methods agree on which dimensions are dynamic.
216205
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
217206
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
218207
ArrayRef<int64_t> outerDimsPerm = {});

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

+7-15
Original file line numberDiff line numberDiff line change
@@ -5006,7 +5006,8 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50065006

50075007
// Insert tensor.cast ops if static shape inference is available..
50085008
SmallVector<int64_t> srcShape, destShape;
5009-
if (inferStaticShape(packOp, srcShape, destShape)) {
5009+
if (inferStaticShape(packOp, srcShape, destShape) &&
5010+
packOp.hasPureTensorSemantics()) {
50105011
Location loc = packOp.getLoc();
50115012
Value source = packOp.getSource();
50125013
if (srcShape != packOp.getSourceType().getShape()) {
@@ -5030,20 +5031,11 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50305031
// Insert a cast if needed
50315032
if (needUpdateDestType) {
50325033
rewriter.setInsertionPointAfter(packOp);
5033-
Operation *castOp;
5034-
bool hasTensorSemantics = packOp.hasPureTensorSemantics();
5035-
if (hasTensorSemantics) {
5036-
castOp =
5037-
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
5038-
} else {
5039-
castOp =
5040-
rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
5041-
}
5042-
rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
5043-
} else {
5044-
// TODO: support memref.cast if static shape inference is available.
5045-
return failure();
5034+
auto castOp =
5035+
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
5036+
rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
50465037
}
5038+
50475039
return success();
50485040
}
50495041

@@ -5424,7 +5416,7 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
54245416
if (!tensor::hasFoldableTensorCastOperand(op))
54255417
return failure();
54265418

5427-
// TODO: Support Memref PackOp. Temporarily return failure.
5419+
// TODO: Support Memref UnPackOp. Temporarily return failure.
54285420
if (!op.hasPureTensorSemantics())
54295421
return failure();
54305422

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
268268
highs[pos] = affine::makeComposedFoldedAffineApply(
269269
rewriter, loc, map, {outerSize, origSize, innerSize});
270270
}
271-
// TODO: Need memref.pad operation to support memref operands
271+
272272
RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
273273
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
274274
packingMetadata.reassociations);
@@ -1030,9 +1030,6 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
10301030
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
10311031
linalg::PackOp packOp) {
10321032
Value input = packOp.getSource();
1033-
// TODO: Support Memref PackOp. Temporarily return just Op Source.
1034-
if (!packOp.hasPureTensorSemantics())
1035-
return input;
10361033

10371034
if (!packOp.getPaddingValue()) {
10381035
return input;

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -1889,9 +1889,8 @@ static LogicalResult
18891889
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
18901890
ArrayRef<int64_t> inputVectorSizes) {
18911891
// TODO: Support Memref PackOp. Temporarily return failure.
1892-
if (!unpackOp.hasPureTensorSemantics()) {
1892+
if (!unpackOp.hasPureTensorSemantics())
18931893
return failure();
1894-
}
18951894

18961895
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
18971896
return !getConstantIntValue(res).has_value();

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
315315
// have proven that these are not sliced. In this case we just take
316316
// the full extent of each dimension in the reassociation list.
317317
if (linearizedDimensions[it.index()]) {
318-
llvm::append_range(
319-
offsetsSizesAndStrides,
320-
llvm::map_range(it.value(), [&](int64_t idx) -> Range {
321-
return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
322-
}));
323-
318+
llvm::append_range(offsetsSizesAndStrides,
319+
llvm::map_range(it.value(), [&](int64_t idx) -> Range {
320+
return {zeroAttr, collapseShapeInputShape[idx],
321+
oneAttr};
322+
}));
324323
continue;
325324
}
326325

mlir/test/Dialect/Linalg/canonicalize.mlir

+1-3
Original file line numberDiff line numberDiff line change
@@ -1894,9 +1894,7 @@ func.func @fold_cast_unpack_dynamic_tile_size(
18941894
return %unpack : tensor<7x?xi32>
18951895
}
18961896

1897-
//===----------------------------------------------------------------------===//
1898-
// linalg.unpack + linalg.pack
1899-
//===----------------------------------------------------------------------===//
1897+
// -----
19001898

19011899
// CHECK-LABEL: func.func @fold_pack_unpack_tensor
19021900
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32>

0 commit comments

Comments
 (0)