Skip to content

[mlir][vector] Canonicalize/fold 'order preserving' transposes #135841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 80 additions & 10 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5594,6 +5594,29 @@ LogicalResult ShapeCastOp::verify() {
return success();
}

/// Return true if `transpose` is 'order preserving', i.e. if it does not
/// change the relative order of any pair of non-unit dimensions.
///
/// For an order preserving transpose the flattened versions of the input and
/// output vectors are (numerically) identical. In other words, `transpose` is
/// effectively a shape cast that only moves unit dimensions around.
///
/// A scalable dimension of base size 1 (`[1]`) has a runtime extent of
/// `vscale x 1`, which need not be 1, so it is treated like any other non-unit
/// dimension here instead of relying on callers to reject scalable types
/// before calling this helper.
static bool isOrderPreserving(TransposeOp transpose) {
  ArrayRef<int64_t> permutation = transpose.getPermutation();
  auto sourceType = transpose.getSourceVectorType();
  ArrayRef<int64_t> inShape = sourceType.getShape();
  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };

  // Walk the result dimensions in order. The source dimensions they read
  // from must be increasing once true unit dimensions are ignored.
  int64_t current = 0;
  for (int64_t p : permutation) {
    if (isNonScalableUnitDim(p))
      continue;
    if (p < current)
      return false;
    current = p;
  }
  return true;
}

OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

// No-op shape cast.
Expand All @@ -5602,13 +5625,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

VectorType resultType = getType();

// Canceling shape casts.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {

// Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
// shape_cast(something(x)) -> x, or
// -> shape_cast(x).
//
// Confirms that a new shape_cast will have valid semantics (expands OR
// collapses dimensions).
auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
VectorType srcType = source.getType();
if (resultType == srcType)
return otherOp.getSource();
return source;
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
Expand All @@ -5618,8 +5643,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
} else {
return {};
}
setOperand(otherOp.getSource());
setOperand(source);
return getResult();
};

// Canceling shape casts.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
TypedValue<VectorType> source = otherOp.getSource();
return maybeFold(source);
}

// shape_cast(transpose(x)) -> shape_cast(x)
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
if (transpose.getType().isScalable())
return {};
if (isOrderPreserving(transpose)) {
TypedValue<VectorType> source = transpose.getVector();
return maybeFold(source);
}
return {};
}

// Cancelling broadcast and shape cast ops.
Expand All @@ -5646,7 +5688,7 @@ namespace {
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
/// vector<4x1x1xi1> --> vector<4x1>
/// vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
ArrayRef<int64_t> oldShape = oldType.getShape();
Expand Down Expand Up @@ -6113,6 +6155,34 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};

/// Rewrites transpose(shape_cast(x)) as a single shape_cast(x) when the
/// transpose is 'order preserving' (see isOrderPreserving) and the transposed
/// vector type is not scalable.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Scalable result types are not handled by this pattern.
    VectorType transposedType = transposeOp.getType();
    if (transposedType.isScalable())
      return failure();

    // The transposed value must be produced directly by a shape_cast.
    auto producer =
        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
    if (!producer)
      return failure();

    // Only transposes that leave the flattened element order untouched can
    // be absorbed into the shape_cast.
    if (!isOrderPreserving(transposeOp))
      return failure();

    // No isValidShapeCast check is performed here: an order preserving
    // transpose just inserts/removes ones, so merging it into the shape_cast
    // is expected to yield a valid shape_cast.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transposeOp, transposedType, producer.getSource());
    return success();
  }
};

/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
Expand Down Expand Up @@ -6211,8 +6281,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {

void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
FoldTransposeBroadcast>(context);
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
FoldTransposeSplat, FoldTransposeBroadcast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
%0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}

// -----

// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
Expand Down Expand Up @@ -3035,7 +3036,6 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>)
return %1 : vector<4x8xf32>
}


// -----

// CHECK-LABEL: @insert_scalar_poison_idx
Expand Down
64 changes: 64 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,67 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
return %1 : vector<3x3x3xi8>
}


// -----

// Positive test: shape_cast(transpose(x)) is folded to shape_cast(x).
// The permutation [1, 0, 3, 4, 2] places the non-unit input dimensions
// (1 and 2) at the following output dimensions:
//   input dim 1 -> output dim 0
//   input dim 2 -> output dim 4
// Because 0 < 4, the non-unit dimensions keep their relative order, so this
// permutation is order preserving and effectively a shape_cast.
// CHECK-LABEL: @transpose_shape_cast
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
%0 = vector.transpose %arg, [1, 0, 3, 4, 2]
: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
%1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
return %1 : vector<4x4xi8>
}

// -----

// Negative test: the transpose must be kept in front of the shape_cast.
// The permutation [0, 2, 1, 3] places the non-unit input dimensions
// (1 and 2) at the following output dimensions:
//   input dim 1 -> output dim 2
//   input dim 2 -> output dim 1
// As this is not increasing (2 > 1), the two non-unit dimensions are swapped:
// this transpose is not order preserving and cannot be treated as a
// shape_cast.
// CHECK-LABEL: @negative_transpose_shape_cast
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
%0 = vector.transpose %arg, [0, 2, 1, 3]
: vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
%1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
return %1 : vector<4x4xi8>
}

// -----

// Positive test: transpose(shape_cast(x)) is folded to a single shape_cast.
// The permutation [0, 2, 1] only swaps the two trailing unit dimensions of
// vector<6x1x1xi8>; the single non-unit dimension (0) stays in place, so the
// transpose is order preserving.
// CHECK-LABEL: @shape_cast_transpose
// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
%0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
%1 = vector.transpose %0, [0, 2, 1]
: vector<6x1x1xi8> to vector<6x1x1xi8>
return %1 : vector<6x1x1xi8>
}

// -----

// Negative test: the transpose after the shape_cast must be kept.
// The permutation [1, 0] swaps the two non-unit dimensions of
// vector<3x2xi8>, so the transpose is not order preserving and cannot be
// merged into the shape_cast.
// CHECK-LABEL: @negative_shape_cast_transpose
// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
%0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
return %1 : vector<2x3xi8>
}
Loading