Skip to content

Commit ff92faa

Browse files
committed
extract as large a chunk as possible in shape_cast lowering
1 parent 1b6b036 commit ff92faa

File tree

2 files changed

+99
-32
lines changed

2 files changed

+99
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@ using namespace mlir;
2828
using namespace mlir::vector;
2929

3030
/// Increments n-D `indices` by `step` starting from the innermost dimension,
/// carrying any overflow of a dimension into the next-outer dimension.
31-
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
31+
static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
3232
int step = 1) {
3333
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
34-
assert(indices[dim] < vecType.getDimSize(dim) &&
35-
"Indices are out of bound");
34+
int64_t dimSize = shape[dim];
35+
assert(indices[dim] < dimSize && "Indices are out of bound");
36+
3637
indices[dim] += step;
37-
if (indices[dim] < vecType.getDimSize(dim))
38+
39+
int64_t spill = indices[dim] / dimSize;
40+
if (spill == 0)
3841
break;
3942

40-
indices[dim] = 0;
41-
step = 1;
43+
indices[dim] %= dimSize;
44+
step = spill;
4245
}
4346
}
4447

@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
7982
// and destination slice insertion and generate such instructions.
8083
for (int64_t i = 0; i < numElts; ++i) {
8184
if (i != 0) {
82-
incIdx(srcIdx, sourceVectorType, /*step=*/1);
83-
incIdx(resIdx, resultVectorType, /*step=*/extractSize);
85+
incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
86+
incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
8487
}
8588

8689
Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
131134
Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
132135
for (int64_t i = 0; i < numElts; ++i) {
133136
if (i != 0) {
134-
incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
135-
incIdx(resIdx, resultVectorType, /*step=*/1);
137+
incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
138+
incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
136139
}
137140

138141
Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
157160
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
158161
PatternRewriter &rewriter) const override {
159162
Location loc = op.getLoc();
160-
auto sourceVectorType = op.getSourceVectorType();
161-
auto resultVectorType = op.getResultVectorType();
163+
VectorType sourceType = op.getSourceVectorType();
164+
VectorType resultType = op.getResultVectorType();
162165

163-
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
166+
if (sourceType.isScalable() || resultType.isScalable())
164167
return failure();
165168

166-
// Special case for n-D / 1-D lowerings with better implementations.
167-
int64_t srcRank = sourceVectorType.getRank();
168-
int64_t resRank = resultVectorType.getRank();
169-
if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
169+
// Special case for n-D / 1-D lowerings with implementations that use
170+
// extract_strided_slice / insert_strided_slice.
171+
int64_t sourceRank = sourceType.getRank();
172+
int64_t resultRank = resultType.getRank();
173+
if ((sourceRank > 1 && resultRank == 1) ||
174+
(sourceRank == 1 && resultRank > 1))
170175
return failure();
171176

172-
// Generic ShapeCast lowering path goes all the way down to unrolled scalar
173-
// extract/insert chains.
174-
int64_t numElts = 1;
175-
for (int64_t r = 0; r < srcRank; r++)
176-
numElts *= sourceVectorType.getDimSize(r);
177+
int64_t numExtracts = sourceType.getNumElements();
178+
int64_t nbCommonInnerDims = 0;
179+
while (true) {
180+
int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
181+
int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
182+
if (sourceDim < 0 || resultDim < 0)
183+
break;
184+
int64_t dimSize = sourceType.getDimSize(sourceDim);
185+
if (dimSize != resultType.getDimSize(resultDim))
186+
break;
187+
numExtracts /= dimSize;
188+
++nbCommonInnerDims;
189+
}
190+
177191
// Replace with data movement operations:
178192
// x[0,0,0] = y[0,0]
179193
// x[0,0,1] = y[0,1]
180194
// x[0,1,0] = y[0,2]
181195
// etc., incrementing the two index vectors "row-major"
182196
// within the source and result shape.
183-
SmallVector<int64_t> srcIdx(srcRank, 0);
184-
SmallVector<int64_t> resIdx(resRank, 0);
185-
Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
186-
for (int64_t i = 0; i < numElts; i++) {
197+
SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
198+
SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
199+
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
200+
201+
for (int64_t i = 0; i < numExtracts; i++) {
187202
if (i != 0) {
188-
incIdx(srcIdx, sourceVectorType);
189-
incIdx(resIdx, resultVectorType);
203+
incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
204+
incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
190205
}
191206

192207
Value extract =
193-
rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
194-
result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
208+
rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
209+
result =
210+
rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
195211
}
196212
rewriter.replaceOp(op, result);
197213
return success();
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
329345

330346
// 4. Increment the insert/extract indices, stepping by minExtractionSize
331347
// for the trailing dimensions.
332-
incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
333-
incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
348+
incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
349+
incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
334350
}
335351

336352
rewriter.replaceOp(op, result);

mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,57 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
140140
return %s : vector<f32>
141141
}
142142

143+
144+
// CHECK-LABEL: func.func @shape_cast_squeeze_leading_one(
145+
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
146+
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
147+
// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
148+
// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
149+
func.func @shape_cast_squeeze_leading_one(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
150+
%s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
151+
return %s : vector<2x3xf32>
152+
}
153+
154+
// CHECK-LABEL: func.func @shape_cast_squeeze_middle_one(
155+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
156+
// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
157+
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
158+
// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
159+
// CHECK-SAME: into vector<2x3xf32>
160+
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
161+
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
162+
// CHECK-SAME: into vector<2x3xf32>
163+
// CHECK: return %[[I1]] : vector<2x3xf32>
164+
func.func @shape_cast_squeeze_middle_one(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
165+
%s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
166+
return %s : vector<2x3xf32>
167+
}
168+
169+
// CHECK-LABEL: func.func @shape_cast_unsqueeze_leading_one(
170+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
171+
// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
172+
// CHECK: %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
173+
// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
174+
// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
175+
func.func @shape_cast_unsqueeze_leading_one(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
176+
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
177+
return %s : vector<1x2x3xf32>
178+
}
179+
180+
// CHECK-LABEL: func.func @shape_cast_unsqueeze_middle_one(
181+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
182+
// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
183+
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
184+
// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
185+
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
186+
// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
187+
// CHECK: return %[[I1]] : vector<2x1x3xf32>
188+
func.func @shape_cast_unsqueeze_middle_one(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
189+
%s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
190+
return %s : vector<2x1x3xf32>
191+
}
192+
193+
143194
module attributes {transform.with_named_sequence} {
144195
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
145196
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)