Skip to content

[mlir][vector] transpose(broadcast) -> broadcast canonicalization #135096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 96 additions & 24 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6085,28 +6085,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};

// Rewrites transpose(broadcast(x)) as a single broadcast(x) whenever x is
// either not a vector (i.e. a scalar) or a vector with exactly one element.
// In both cases there is only one distinct value, so permuting dimensions
// cannot change which value lands where.
struct FoldTransposedScalarBroadcast final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Only transposes fed directly by a broadcast are candidates.
    auto producer =
        transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!producer)
      return failure();

    // A null cast result means the broadcast source is not a vector. Bail out
    // only when the source is a vector carrying more than one element.
    auto sourceVecTy = llvm::dyn_cast<VectorType>(producer.getSourceType());
    const bool isMultiElementVector =
        sourceVecTy && sourceVecTy.getNumElements() != 1;
    if (isMultiElementVector)
      return failure();

    // Broadcast the original source straight to the transposed result type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        transposeOp, transposeOp.getResultVectorType(), producer.getSource());
    return success();
  }
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
Expand Down Expand Up @@ -6161,12 +6139,106 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};

/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
///
/// Example:
/// ```
///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
///                                  to vector<8x1xi32>
/// ```
/// can be rewritten as the equivalent
/// ```
///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>
/// ```
/// The algorithm works by partitioning dimensions into groups that can be
/// locally permuted while preserving order, and checks that the transpose
/// only permutes within these groups.
///
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
/// broadcasting from 1x1x4x1x1x7.
///                   ^^^ ^ ^^^ ^
///          groups:   0  1  2  3
/// Order preserving permutations for this example are ones that only permute
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  /// Matches a transpose whose operand is produced by a vector.broadcast and,
  /// when the permutation is local to every dimension group of the broadcast
  /// source, replaces the transpose with a broadcast of the original source
  /// directly to the transposed type. Reports a match failure otherwise.
  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {

    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    // A null `inputType` means the broadcast source is not a vector (scalar).
    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid.
    // The replacement must read the broadcast's *source*: the transpose
    // operand is the already-broadcast vector, whose shape is in general not
    // broadcastable to the transposed shape.
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = outputType.getRank();
    // Number of leading dimensions the broadcast prepends to the source shape.
    int64_t deltaRank = outputRank - inputRank;

    // `low` is the first output dimension of the group currently being built.
    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      // A group ends right before every non-1 dimension, and right after one.
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Return failure if not all permutation destinations for indices in
        // [low, high) are in [low, high), i.e. the permutation is not local to
        // the group.
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        // The next group starts where this one ended.
        low = high;
      }
    }

    // We don't need to check the final group [low, outputRank) because if it is
    // not locally bound, there must be a preceding group that already failed
    // the check (impossible to have just 1 non-locally bound group).

    // The preceding logic also ensures that at this point, the output of the
    // transpose is definitely broadcastable from the input shape, assert so:
    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    // Broadcast the original source (not the intermediate broadcast result)
    // straight to the transposed type.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());

    return success();
  }
};

} // namespace

/// Registers the canonicalization patterns for vector.transpose into
/// `results`. Every pattern is registered exactly once, in a single call.
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
              TransposeFolder, FoldTransposeSplat, FoldTransposeBroadcast>(
      context);
}

//===----------------------------------------------------------------------===//
Expand Down
24 changes: 0 additions & 24 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2218,30 +2218,6 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5

// -----

// Transpose of a broadcast whose source is a single-element vector: the
// expected output is one broadcast from the argument directly to the
// transposed shape (1x8), with no transpose left.
// CHECK-LABEL: func @transpose_scalar_broadcast1
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
// CHECK: return %[[V]] : vector<1x8xf32>
func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
%bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
return %t : vector<1x8xf32>
}

// -----

// Transpose of a broadcast whose source is a scalar (f32): the expected
// output is one broadcast from the argument directly to the transposed
// shape (1x8), with no transpose left.
// CHECK-LABEL: func @transpose_scalar_broadcast2
// CHECK-SAME: (%[[ARG:.+]]: f32)
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
// CHECK: return %[[V]] : vector<1x8xf32>
func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
%bcast = vector.broadcast %value : f32 to vector<8x1xf32>
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
return %t : vector<1x8xf32>
}

// -----

// CHECK-LABEL: func @transpose_splat_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
Expand Down
139 changes: 139 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file contains some canonicalization tests involving vector.transpose.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, it's totally valid (and something I personally encourage) to document what pattern specifically is being tested:


// Transpose of a broadcast whose source is a single-element vector: the
// expected output is one broadcast from the argument directly to the
// transposed shape (1x8), with no transpose left.
// CHECK-LABEL: func @transpose_scalar_broadcast1
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
// CHECK: return %[[V]] : vector<1x8xf32>
func.func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
%bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
return %t : vector<1x8xf32>
}

// -----

// Transpose of a broadcast whose source is a scalar (f32): the expected
// output is one broadcast from the argument directly to the transposed
// shape (1x8), with no transpose left.
// CHECK-LABEL: func @transpose_scalar_broadcast2
// CHECK-SAME: (%[[ARG:.+]]: f32)
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
// CHECK: return %[[V]] : vector<1x8xf32>
func.func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
%bcast = vector.broadcast %value : f32 to vector<8x1xf32>
%t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
return %t : vector<1x8xf32>
}

// -----


// Scalar source, rank-3 result, permutation [2, 0, 1]: the expected output is
// a single broadcast of the i8 argument directly to the transposed shape
// (2x3x4), which is then returned.
// CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
func.func @broadcast_transpose_scalar_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
// CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
%0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
// CHECK: return %[[BC]] : vector<2x3x4xi8>
return %1 : vector<2x3x4xi8>
}

// -----

// Source whose dimensions are all 1 (1x1x1), permutation [2, 0, 1]: the
// expected output is a single broadcast of the argument directly to the
// transposed shape (2x3x4).
// CHECK-LABEL: broadcast_transpose_ones_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
// CHECK: return %[[RES]] : vector<2x3x4xi8>
func.func @broadcast_transpose_ones_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
return %1 : vector<2x3x4xi8>
}

// -----

// Rank-1 single-element source (vector<1xi8>) broadcast to 1x8 and then
// transposed to 8x1: the expected output is a single broadcast of the
// argument directly to 8x1.
// CHECK-LABEL: broadcast_transpose_partial_ones_to_broadcast
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
// CHECK: return %[[RES]] : vector<8x1xi8>
func.func @broadcast_transpose_partial_ones_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
return %1 : vector<8x1xi8>
}

// -----

// Source 4x1x1x7 broadcast to rank 6; the transpose only swaps the two leading
// dimensions (sizes 2 and 3), which the rank-4 source does not have. The
// expected output is a single broadcast of the argument directly to
// 3x2x4x5x6x7.
// CHECK-LABEL: broadcast_transpose_mixed_example
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x1x1x7xi8> to vector<3x2x4x5x6x7xi8>
// CHECK: return %[[RES]] : vector<3x2x4x5x6x7xi8>
func.func @broadcast_transpose_mixed_example(%arg0 : vector<4x1x1x7xi8>) -> vector<3x2x4x5x6x7xi8> {
%0 = vector.broadcast %arg0 : vector<4x1x1x7xi8> to vector<2x3x4x5x6x7xi8>
%1 = vector.transpose %0, [1, 0, 2, 3, 4, 5] : vector<2x3x4x5x6x7xi8> to vector<3x2x4x5x6x7xi8>
return %1 : vector<3x2x4x5x6x7xi8>
}

// -----

// Source 4x7x1x1; the transpose only swaps the two trailing dimensions, both
// of which have size 1 in the source. The expected output is a single
// broadcast of the argument directly to 4x7x2x3.
// CHECK-LABEL: broadcast_transpose_final_group
// CHECK-SAME: %[[ARG:.*]]: vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x7x1x1xi8> to vector<4x7x2x3xi8>
// CHECK: return %[[RES]] : vector<4x7x2x3xi8>
func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
%0 = vector.broadcast %arg0 : vector<4x7x1x1xi8> to vector<4x7x3x2xi8>
%1 = vector.transpose %0, [0, 1, 3, 2] : vector<4x7x3x2xi8> to vector<4x7x2x3xi8>
return %1 : vector<4x7x2x3xi8>
}

// -----

// Negative test: source 4x1 broadcast to 4x4, then transposed with [1, 0].
// The swap exchanges the size-4 source dimension with the size-1 source
// dimension, so both the broadcast and the transpose are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_square
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
// CHECK: return %[[TRP]] : vector<4x4xi8>
func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
%0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
%1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
return %1 : vector<4x4xi8>
}

// -----

// Negative test: source 1x1x4 broadcast to 4x4x4x4, then transposed with
// [1, 0, 3, 2]. The permutation moves the source's trailing size-4 dimension,
// so both the broadcast and the transpose are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_hypercube
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
return %1 : vector<4x4x4x4xi8>
}

// -----

// Negative test: source 3x1x3 broadcast to 3x3x3, then transposed with
// [1, 0, 2]. The swap exchanges the leading size-3 source dimension with the
// size-1 source dimension, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_102
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
}

// -----

// Negative test: source 3x1x3 broadcast to 3x3x3, then transposed with
// [0, 2, 1]. The swap exchanges the size-1 source dimension with the trailing
// size-3 source dimension, so both ops are expected to remain.
// CHECK-LABEL: negative_broadcast_transpose_021
// CHECK-SAME: %[[ARG:.*]]:
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
}

Loading