[mlir][vector] Canonicalize/fold 'order preserving' transposes #135841

Merged
merged 1 commit into from
May 1, 2025
91 changes: 80 additions & 11 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5575,6 +5575,34 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }

+namespace {

Contributor Author:

Thanks. I intend to change this in a follow-up PR.


+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  VectorType sourceType = transpose.getSourceVectorType();
+  ArrayRef<int64_t> inShape = sourceType.getShape();
+  ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
+  auto isNonScalableUnitDim = [&](int64_t dim) {
+    return inShape[dim] == 1 && !inDimIsScalable[dim];
+  };
+  int64_t current = 0;
+  for (auto p : permutation) {
+    if (!isNonScalableUnitDim(p)) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

   VectorType resultType = getType();
@@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   if (getSource().getType() == resultType)
     return getSource();

-  // Y = shape_cast(shape_cast(X)))
-  // -> X, if X and Y have same type
-  // -> shape_cast(X) otherwise.
-  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-    VectorType srcType = otherOp.getSource().getType();
-    if (resultType == srcType)
-      return otherOp.getSource();
-    setOperand(otherOp.getSource());
+  // shape_cast(shape_cast(x)) -> shape_cast(x)
+  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
+    setOperand(precedingShapeCast.getSource());
     return getResult();
   }

+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    // This folder does
+    // shape_cast(transpose) -> shape_cast
+    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
+    // shape_cast -> shape_cast(transpose)
+    // i.e. the complete opposite. When paired, these 2 patterns can cause
+    // infinite cycles in pattern rewriting.
+    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
+    // vectors, so by disabling this folder for scalable vectors the
+    // cycle is avoided.
+    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is

Contributor:

Do you know why we generate illegal shape cast ops in the first place? It sounds like something that shouldn't happen...

Contributor Author:

They're illegal according to an Arm-specific lowering target, which I'm not familiar with:

struct ConvertIllegalShapeCastOpsToTransposes

I think @banach-space suspects that they're actually not illegal and will investigate removing this constraint here (and the pattern ConvertIllegalShapeCastOpsToTransposes).

Contributor (@banach-space, May 12, 2025):

"Illegal" in this context means "Illegal as an SME target - transform before invoking -convert-vector-arm-sme".

Tidying this up is high on my TODO list (will look into it this week).

Contributor:

Looks like ConvertIllegalShapeCastOpsToTransposes is no longer required:

Contributor Author:

🎉 thanks @banach-space

+    // still needed. If it's not, then we can fold here.
+    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+      setOperand(transpose.getVector());
+      return getResult();
+    }
+    return {};
+  }
+
   // Y = shape_cast(broadcast(X))
   // -> X, if X and Y have same type
   if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
@@ -5619,7 +5662,7 @@ namespace {
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
-/// vector<4x1x1xi1> --> vector<4x1>
+/// vector<4x1x1xi1> --> vector<4x1xi1>
 ///
 static VectorType trimTrailingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6086,6 +6129,32 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };

+/// Folds transpose(shape_cast) into a new shape_cast.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto shapeCastOp =
+        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+    if (!isOrderPreserving(transposeOp))
+      return failure();
+
+    VectorType resultType = transposeOp.getType();
+
+    // We don't need to check isValidShapeCast at this point, because it is
+    // guaranteed that merging the transpose into the shape_cast is a valid
+    // shape_cast, because the transpose just inserts/removes ones.
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+                                                     shapeCastOp.getSource());
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6184,8 +6253,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {

 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
-              FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
+              FoldTransposeSplat, FoldTransposeBroadcast>(context);
 }

//===----------------------------------------------------------------------===//
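
The index logic of `isOrderPreserving` can be tried out without an MLIR build. The following is a minimal sketch, not part of this PR: the name `isOrderPreservingSketch` is hypothetical, and plain `std::vector` arguments stand in for the `VectorType` and `TransposeOp` accessors. It walks the permutation, skips fixed-size unit dims, and requires the remaining source dims to appear in increasing order.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for isOrderPreserving: plain vectors replace the
// VectorType / TransposeOp accessors used in the PR.
bool isOrderPreservingSketch(const std::vector<int64_t> &inShape,
                             const std::vector<bool> &inDimIsScalable,
                             const std::vector<int64_t> &permutation) {
  // Only fixed-size unit dims may be ignored; a scalable `[1]` dim may not.
  auto isNonScalableUnitDim = [&](int64_t dim) {
    return inShape[dim] == 1 && !inDimIsScalable[dim];
  };
  int64_t current = 0;
  for (int64_t p : permutation) {
    if (isNonScalableUnitDim(p))
      continue;
    // The source dims that matter must appear in increasing order.
    if (p < current)
      return false;
    current = p;
  }
  return true;
}

// Shapes and permutations copied from the tests added in this PR.
void checkExamplesFromTests() {
  // vector<1x4x4x1x1xi8>, permutation [1, 0, 3, 4, 2]
  assert(isOrderPreservingSketch({1, 4, 4, 1, 1},
                                 std::vector<bool>(5, false),
                                 {1, 0, 3, 4, 2}));
  // vector<1x4x4x1xi8>, permutation [0, 2, 1, 3]
  assert(!isOrderPreservingSketch({1, 4, 4, 1},
                                  std::vector<bool>(4, false),
                                  {0, 2, 1, 3}));
  // vector<[1]x4xi8>, permutation [1, 0]: the scalable dim counts as non-unit.
  assert(!isOrderPreservingSketch({1, 4}, {true, false}, {1, 0}));
}
```

Calling `checkExamplesFromTests()` from a test driver exercises one positive case, one negative case, and the scalable unit-dim case.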
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/canonicalize.mlir
@@ -8,6 +8,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
+
// -----

// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
@@ -3061,7 +3062,6 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>)
   return %1 : vector<4x8xf32>
 }

-
// -----

// CHECK-LABEL: @insert_scalar_poison_idx
110 changes: 110 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -137,3 +137,113 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
   return %1 : vector<3x3x3xi8>
 }

+
+// -----
+
+// Test of shape_cast(transpose) folding (ShapeCastOp::fold).
+// In this test, the permutation maps the non-unit source dimensions (1 and 2)
+// to result dimensions as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// Test of shape_cast(transpose) folding (ShapeCastOp::fold).
+// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3]
+     : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+  %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// Test of shape_cast(transpose) folding (ShapeCastOp::fold).
+// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
+// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes.
+// CHECK-LABEL: @negative_transpose_shape_cast_scalable
+// CHECK: vector.transpose
+// CHECK: vector.shape_cast
+func.func @negative_transpose_shape_cast_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
+  %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
+  %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
+  return %1 : vector<[4]xi8>
+}
+
+// -----
+
+// Test of FoldTransposeShapeCast.
+// The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
+// vectors.
+// CHECK-LABEL: @shape_cast_transpose_scalable
+// CHECK: vector.shape_cast
+// CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
+  return %1 : vector<[4]x1xi8>
+}
+
+// -----
+
+// Test of FoldTransposeShapeCast.
+// A transpose that is 'order preserving' can be treated like a shape_cast.
+// CHECK-LABEL: @shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<6x1x1xi8> to vector<6x1x1xi8>
+  return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// Test of FoldTransposeShapeCast.
+// Scalable dimensions should be treated as non-unit dimensions.
+// CHECK-LABEL: @shape_cast_transpose_scalable_unit
+// CHECK: vector.shape_cast
+// CHECK: vector.transpose
+func.func @shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+  %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
+  return %1 : vector<4x[1]xi8>
+}
+
+// -----
+
+// Test of FoldTransposeShapeCast (negative case).
+// The permutation [1, 0] swaps two non-unit dimensions, so the transpose is
+// not order preserving and must not be folded into the shape_cast.
+// CHECK-LABEL: @negative_shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+  %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+  return %1 : vector<2x3xi8>
+}
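
The positive and negative tests above rely on the claim that an 'order preserving' transpose leaves the flattened element order unchanged. As a rough cross-check, that property can be verified by brute force for fixed-size shapes. The sketch below is not part of this PR: `flattenedOrderUnchanged` is a hypothetical helper, it assumes row-major flattening and strictly positive extents, and it does not model scalable dimensions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical brute-force check: returns true if transposing a row-major
// tensor of shape `inShape` by `perm` keeps every element at the same
// flattened position. Result dim i takes its extent from source dim perm[i].
bool flattenedOrderUnchanged(const std::vector<int64_t> &inShape,
                             const std::vector<int64_t> &perm) {
  const std::size_t rank = inShape.size();

  std::vector<int64_t> outShape(rank);
  for (std::size_t i = 0; i < rank; ++i)
    outShape[i] = inShape[perm[i]];

  // Row-major strides of the source.
  std::vector<int64_t> inStrides(rank, 1);
  for (std::size_t i = rank; i > 1; --i)
    inStrides[i - 2] = inStrides[i - 1] * inShape[i - 1];

  int64_t numElements = 1;
  for (int64_t extent : inShape)
    numElements *= extent;

  for (int64_t outLinear = 0; outLinear < numElements; ++outLinear) {
    // Decompose the result position into per-dim indices, innermost first,
    // and accumulate the source position those indices refer to.
    int64_t remainder = outLinear;
    int64_t inLinear = 0;
    for (std::size_t i = rank; i > 0; --i) {
      int64_t index = remainder % outShape[i - 1];
      remainder /= outShape[i - 1];
      inLinear += index * inStrides[perm[i - 1]];
    }
    if (inLinear != outLinear)
      return false;
  }
  return true;
}

// Fixed-size shapes and permutations copied from the tests in this PR.
void checkFlattenedOrderExamples() {
  assert(flattenedOrderUnchanged({1, 4, 4, 1, 1}, {1, 0, 3, 4, 2}));
  assert(!flattenedOrderUnchanged({1, 4, 4, 1}, {0, 2, 1, 3}));
  assert(flattenedOrderUnchanged({6, 1, 1}, {0, 2, 1}));
  assert(!flattenedOrderUnchanged({3, 2}, {1, 0}));
}
```

For these four fixed-size cases the brute-force result agrees with what the CHECK lines expect: the transpose is folded away exactly when the flattened order is unchanged.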