-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][vector] Canonicalize/fold 'order preserving' transposes #135841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: James Newling (newling) ChangesHandles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example
can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that
The pattern Full diff: https://github.com/llvm/llvm-project/pull/135841.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..5da0ef0af032f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5621,6 +5621,29 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
+namespace {
+
+/// Return true if `transpose` does not permute a pair of dimensions that are
+/// both not of size 1. By `order preserving` we mean that the flattened
+/// versions of the input and output vectors are (numerically) identical.
+/// In other words `transpose` is effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
+ ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
+ int64_t current = 0;
+ for (auto p : permutation) {
+ if (inShape[p] != 1) {
+ if (p < current) {
+ return false;
+ }
+ current = p;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// No-op shape cast.
@@ -5629,13 +5652,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
- // Canceling shape casts.
- if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
- // Only allows valid transitive folding (expand/collapse dimensions).
- VectorType srcType = otherOp.getSource().getType();
+ // shape_cast(something(x)) -> x, or
+ // -> shape_cast(x).
+ //
+ // Confirms that a new shape_cast will have valid semantics (expands OR
+ // collapses dimensions).
+ auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
+ VectorType srcType = source.getType();
if (resultType == srcType)
- return otherOp.getSource();
+ return source;
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
@@ -5645,8 +5670,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
} else {
return {};
}
- setOperand(otherOp.getSource());
+ setOperand(source);
return getResult();
+ };
+
+ // Canceling shape casts.
+ if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
+ TypedValue<VectorType> source = otherOp.getSource();
+ return maybeFold(source);
+ }
+
+ // shape_cast(transpose(x)) -> shape_cast(x)
+ if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+ if (transpose.getType().isScalable())
+ return {};
+ if (isOrderPreserving(transpose)) {
+ TypedValue<VectorType> source = transpose.getVector();
+ return maybeFold(source);
+ }
+ return {};
}
// Cancelling broadcast and shape cast ops.
@@ -5675,7 +5717,7 @@ namespace {
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
-/// vector<4x1x1xi1> --> vector<4x1>
+/// vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6161,12 +6203,40 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(shape_cast) into a new shape_cast.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto shapeCastOp =
+ transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+ if (!shapeCastOp)
+ return failure();
+ if (!isOrderPreserving(transposeOp))
+ return failure();
+ if (transposeOp.getType().isScalable())
+ return failure();
+
+ VectorType resultType = transposeOp.getType();
+
+ // We don't need to check isValidShapeCast at this point, because it is
+ // guaranteed that merging the transpose into the the shape_cast is a valid
+ // shape_cast, because the transpose just inserts/removes ones.
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+ shapeCastOp.getSource());
+ return success();
+ }
+};
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ FoldTransposeShapeCast, TransposeFolder, FoldTransposeSplat>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..10144cb9034e4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3295,3 +3295,67 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}
+
+// -----
+
+// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+ : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+ %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-one indices (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.transpose %arg, [0, 2, 1, 3]
+ : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+ %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+ %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+ %1 = vector.transpose %0, [0, 2, 1]
+ : vector<6x1x1xi8> to vector<6x1x1xi8>
+ return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+ %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+ return %1 : vector<2x3xi8>
+}
|
b21a4a6
to
f4ae206
Compare
Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example
can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that
The pattern
ConvertIllegalShapeCastOpsToTransposes
that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors