[mlir][vector] Improve shape_cast lowering #140800
Conversation
Force-pushed from ff92faa to 6c0bc5d.
@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Before this PR, a rank-m -> rank-n vector.shape_cast with m,n > 1 was lowered to extracts/inserts of single elements, so a shape_cast on a vector with N elements would always require N extracts/inserts. While this is necessary in the worst case, it is sometimes possible to use fewer, larger extracts/inserts. Specifically, the largest common suffix of the source and result shapes can be extracted/inserted as a whole. For example:

%0 = vector.shape_cast %arg0 : vector<10x2x3xf32> to vector<2x5x2x3xf32>

has a common suffix of shape 2x3. This case was first mentioned here: #138777 (comment)

Full diff: https://github.com/llvm/llvm-project/pull/140800.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 23324a007377e..d0085bffca23c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -28,17 +28,20 @@ using namespace mlir;
using namespace mlir::vector;
/// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
int step = 1) {
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
- assert(indices[dim] < vecType.getDimSize(dim) &&
- "Indices are out of bound");
+ int64_t dimSize = shape[dim];
+ assert(indices[dim] < dimSize && "Indices are out of bound");
+
indices[dim] += step;
- if (indices[dim] < vecType.getDimSize(dim))
+
+ int64_t spill = indices[dim] / dimSize;
+ if (spill == 0)
break;
- indices[dim] = 0;
- step = 1;
+ indices[dim] %= dimSize;
+ step = spill;
}
}
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
// and destination slice insertion and generate such instructions.
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/1);
- incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
}
Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
- incIdx(resIdx, resultVectorType, /*step=*/1);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
}
Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType resultType = op.getResultVectorType();
- if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+ if (sourceType.isScalable() || resultType.isScalable())
return failure();
- // Special case for n-D / 1-D lowerings with better implementations.
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+ // Special case for n-D / 1-D lowerings with implementations that use
+ // extract_strided_slice / insert_strided_slice.
+ int64_t sourceRank = sourceType.getRank();
+ int64_t resultRank = resultType.getRank();
+ if ((sourceRank > 1 && resultRank == 1) ||
+ (sourceRank == 1 && resultRank > 1))
return failure();
- // Generic ShapeCast lowering path goes all the way down to unrolled scalar
- // extract/insert chains.
- int64_t numElts = 1;
- for (int64_t r = 0; r < srcRank; r++)
- numElts *= sourceVectorType.getDimSize(r);
+ int64_t numExtracts = sourceType.getNumElements();
+ int64_t nbCommonInnerDims = 0;
+ while (true) {
+ int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
+ int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
+ if (sourceDim < 0 || resultDim < 0)
+ break;
+ int64_t dimSize = sourceType.getDimSize(sourceDim);
+ if (dimSize != resultType.getDimSize(resultDim))
+ break;
+ numExtracts /= dimSize;
+ ++nbCommonInnerDims;
+ }
+
// Replace with data movement operations:
// x[0,0,0] = y[0,0]
// x[0,0,1] = y[0,1]
// x[0,1,0] = y[0,2]
// etc., incrementing the two index vectors "row-major"
// within the source and result shape.
- SmallVector<int64_t> srcIdx(srcRank, 0);
- SmallVector<int64_t> resIdx(resRank, 0);
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
- for (int64_t i = 0; i < numElts; i++) {
+ SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
+ SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+
+ for (int64_t i = 0; i < numExtracts; i++) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType);
- incIdx(resIdx, resultVectorType);
+ incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
+ incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
}
Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
+ result =
+ rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
}
rewriter.replaceOp(op, result);
return success();
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
- incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
- incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+ incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
+ incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..2875f159a2df9 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -140,6 +140,59 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
return %s : vector<f32>
}
+
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME: into vector<2x3xf32>
+// CHECK: return %[[I1]] : vector<2x3xf32>
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+ return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @prepend_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
+func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+ return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL: func.func @insert_middle_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK: return %[[I1]] : vector<2x1x3xf32>
+func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+ %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+ return %s : vector<2x1x3xf32>
+}
+
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
Force-pushed from 6c0bc5d to 1a4b759.
Nice improvements, thanks!
Please find some minor comments inline :)
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
if (sourceType.isScalable() || resultType.isScalable())
  return failure();
Since you are refactoring a fair bit, would you mind replacing this (and other instances of `failure`) with `notifyMatchFailure`? Thanks!
// Special case for n-D / 1-D lowerings with implementations that use
// extract_strided_slice / insert_strided_slice.
Could we clarify this comment? (the original part is quite confusing). Right now, combined with the code, it reads a bit like:
This is a special case, let's fail!
😅 I assume that it was meant to be:
This special case is handled by other, more optimal patterns.
Or something similar :)
I didn't much like the logic spread over the 3 patterns (N->N, 1->N, N->1), as there isn't really anything special about the 1->N and N->1 cases. So I've done a fairly major update to the N->N pattern, so that it now handles the 1->N and N->1 cases too. As the new change is quite significant, if you'd prefer it to be done in a separate PR I'm happy to postpone this 'unification', backtrack, and just make the minor changes you suggested to this PR.
I also unified the tests across the test file. The behavior for the 1->N and N->1 cases is unchanged by this PR though.
int64_t numExtracts = sourceType.getNumElements();
int64_t nbCommonInnerDims = 0;
Do both `num` and `nb` stand for "number"? Could you unify?
@@ -28,17 +28,20 @@ using namespace mlir;
using namespace mlir::vector;

/// Increments n-D `indices` by `step` starting from the innermost dimension.
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
Perhaps just a slow day for me, but it took me a while to follow the logic in this method. Let me share my observations:
- `step` only really defines the step for the trailing dim? If yes, it would be good to update the variable name.
- `spill` is either `0` or `1`.
Is this correct?
Btw, extra documentation would help. My initial interpretation was: "Update every single index by 1", but that's not true, is it?
I've refactored and documented this significantly in the latest commit, it is hopefully now clearer
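For readers following along, here is a standalone C++ sketch of the spill/carry logic as it appears in the diff earlier on this page (an illustrative reimplementation using std::vector, not the exact upstream helper): with step = 1 the spill into the next-outer dimension is only ever 0 or 1, but a larger step on the trailing dimension can spill by more.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Carry-propagating increment over an n-D index, mirroring the diff's incIdx:
// `step` is added to the innermost index, and any overflow "spills" into the
// next-outer dimension (divided down by that dimension's size).
static void incIdx(std::vector<int64_t> &indices,
                   const std::vector<int64_t> &shape, int64_t step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    int64_t dimSize = shape[dim];
    assert(indices[dim] < dimSize && "Indices are out of bound");
    indices[dim] += step;
    int64_t spill = indices[dim] / dimSize; // carry into the next-outer dim
    if (spill == 0)
      break;
    indices[dim] %= dimSize;
    step = spill;
  }
}

int main() {
  // Shape 2x5, step 1: [0, 4] -> [1, 0], i.e. plain row-major iteration.
  std::vector<int64_t> a = {0, 4};
  incIdx(a, {2, 5});
  std::printf("[%lld, %lld]\n", (long long)a[0], (long long)a[1]); // [1, 0]

  // Shape 2x5, step 3 on the trailing dim: [0, 4] -> [1, 2]
  // (spill of 1 into the outer dim, remainder 2 left in the trailing dim).
  std::vector<int64_t> b = {0, 4};
  incIdx(b, {2, 5}, /*step=*/3);
  std::printf("[%lld, %lld]\n", (long long)b[0], (long long)b[1]); // [1, 2]
}
```

So the function is not "update every index by 1": it advances the flattened position by `step` and propagates the carry outward.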
// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
This indentation is off.
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
This indentation is inconsistent with what's used in @squeeze_out_middle_unit_dim.
I did a significant test file refactor to make the pre-existing and new tests consistent. The pre-existing test logic is unchanged. I'm happy to postpone refactoring the old tests to make that clearer
I'm happy for you to keep all these changes in this PR - this is effectively a proper refactor of the pattern, which was in dire need of some TLC anyway 🙂
Hey @newling - I’m happy with the expanded scope; the overall direction looks very positive!
I’m slowly running out of cycles this week and will be travelling next week. No need to wait for my approval if another reviewer signs off. I’ll also try to post a few comments while in transit :)
/// Two special cases are handled separately:
/// (1) A shape_cast that just does leading 1 insertion/removal
/// (2) A shape_cast where the gcd is 1.
Where are these cases handled?
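For context, here are hypothetical shape_casts that would fall under each case, as I read the quoted comment (the shapes are illustrative choices, not taken from the PR):

```mlir
// Case (1): only a leading unit dim is inserted (or, symmetrically, removed).
func.func @leading_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
  return %0 : vector<1x2x3xf32>
}

// Case (2): the trailing dims (3 and 2) have gcd 1, so only single-element
// moves are possible.
func.func @gcd_one(%arg0 : vector<2x3xf32>) -> vector<3x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x2xf32>
  return %0 : vector<3x2xf32>
}
```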
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
I'm happy for you to keep all these changes in this PR - this is effectively a proper refactor of the pattern, which was in dire need of some TLC anyway 🙂
Hi @banach-space - ok no worries, thanks for the heads up. And thanks for the numerous and useful reviews :)
Awesome improvement. LGTM % refactoring
/// +-----------------> gcd(4,6) is 2
/// |                        |
/// v                        v
/// atomic shape <-----   2x7x11
Nice! This makes sense to me...
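A hypothetical shape_cast matching my reading of the diagram (the shapes are chosen for illustration, not taken from the PR): the suffix 7x11 is common to both shapes, the next dims are 4 and 6 with gcd 2, so slices of shape 2x7x11 can be moved as units.

```mlir
func.func @gcd_suffix(%arg0 : vector<3x4x7x11xf32>) -> vector<2x6x7x11xf32> {
  // 3*4 == 2*6, so the element counts match and the cast is legal.
  %0 = vector.shape_cast %arg0 : vector<3x4x7x11xf32> to vector<2x6x7x11xf32>
  return %0 : vector<2x6x7x11xf32>
}
```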
// IR is generated in this case if we just extract and insert the elements
// directly. In other words, we don't use extract_strided_slice and
// insert_strided_slice.
if (greatestCommonDivisor == 1) { |
Should we refactor the different cases to functions?
Yes, looks a bit better now IMO
Before this PR, a rank-m -> rank-n vector.shape_cast with m,n > 1 was lowered to extracts/inserts of single elements, so a shape_cast on a vector with N elements would always require N extracts/inserts. While this is necessary in the worst case, it is sometimes possible to use fewer, larger extracts/inserts. Specifically, the largest common suffix of the source and result shapes can be extracted/inserted as a whole. For example:

%0 = vector.shape_cast %arg0 : vector<10x2x3xf32> to vector<2x5x2x3xf32>

has a common suffix of shape 2x3. Before this PR, this would be lowered to 60 extract/insert pairs, with extracts of the form vector.extract %arg0 [a, b, c] : f32 from vector<10x2x3xf32>. With this PR it is 10 extract/insert pairs, with extracts of the form vector.extract %arg0 [a] : vector<2x3xf32> from vector<10x2x3xf32>.

This case was first mentioned here: #138777 (comment)
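A sketch of the lowering this PR aims for on that example, following the description above (the function and SSA value names are illustrative):

```mlir
func.func @suffix_example(%arg0 : vector<10x2x3xf32>) -> vector<2x5x2x3xf32> {
  // 10 extract/insert pairs, each moving a vector<2x3xf32> slice; the source
  // index advances by 1 and the result index row-major over the leading 2x5.
  %init = ub.poison : vector<2x5x2x3xf32>
  %e0 = vector.extract %arg0[0] : vector<2x3xf32> from vector<10x2x3xf32>
  %r0 = vector.insert %e0, %init [0, 0] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e1 = vector.extract %arg0[1] : vector<2x3xf32> from vector<10x2x3xf32>
  %r1 = vector.insert %e1, %r0 [0, 1] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e2 = vector.extract %arg0[2] : vector<2x3xf32> from vector<10x2x3xf32>
  %r2 = vector.insert %e2, %r1 [0, 2] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e3 = vector.extract %arg0[3] : vector<2x3xf32> from vector<10x2x3xf32>
  %r3 = vector.insert %e3, %r2 [0, 3] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e4 = vector.extract %arg0[4] : vector<2x3xf32> from vector<10x2x3xf32>
  %r4 = vector.insert %e4, %r3 [0, 4] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e5 = vector.extract %arg0[5] : vector<2x3xf32> from vector<10x2x3xf32>
  %r5 = vector.insert %e5, %r4 [1, 0] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e6 = vector.extract %arg0[6] : vector<2x3xf32> from vector<10x2x3xf32>
  %r6 = vector.insert %e6, %r5 [1, 1] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e7 = vector.extract %arg0[7] : vector<2x3xf32> from vector<10x2x3xf32>
  %r7 = vector.insert %e7, %r6 [1, 2] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e8 = vector.extract %arg0[8] : vector<2x3xf32> from vector<10x2x3xf32>
  %r8 = vector.insert %e8, %r7 [1, 3] : vector<2x3xf32> into vector<2x5x2x3xf32>
  %e9 = vector.extract %arg0[9] : vector<2x3xf32> from vector<10x2x3xf32>
  %r9 = vector.insert %e9, %r8 [1, 4] : vector<2x3xf32> into vector<2x5x2x3xf32>
  return %r9 : vector<2x5x2x3xf32>
}
```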
General direction of travel:
- shape_cast lowering needs to be at least as 'good' as the lowering of broadcast, transpose, and extract, for the canonicalizations proposed in #138777 and #140583.
- Eventually I would like to move shape_cast lowering (i.e. vector -> vector) to the later conversion step (i.e. vector -> llvm/spirv).