@@ -28,17 +28,20 @@ using namespace mlir;
28
28
using namespace mlir ::vector;
29
29
30
30
// / Increments n-D `indices` by `step` starting from the innermost dimension.
31
- static void incIdx (SmallVectorImpl <int64_t > & indices, VectorType vecType ,
31
+ static void incIdx (MutableArrayRef <int64_t > indices, ArrayRef< int64_t > shape ,
32
32
int step = 1 ) {
33
33
for (int dim : llvm::reverse (llvm::seq<int >(0 , indices.size ()))) {
34
- assert (indices[dim] < vecType.getDimSize (dim) &&
35
- " Indices are out of bound" );
34
+ int64_t dimSize = shape[dim];
35
+ assert (indices[dim] < dimSize && " Indices are out of bound" );
36
+
36
37
indices[dim] += step;
37
- if (indices[dim] < vecType.getDimSize (dim))
38
+
39
+ int64_t spill = indices[dim] / dimSize;
40
+ if (spill == 0 )
38
41
break ;
39
42
40
- indices[dim] = 0 ;
41
- step = 1 ;
43
+ indices[dim] %= dimSize ;
44
+ step = spill ;
42
45
}
43
46
}
44
47
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
79
82
// and destination slice insertion and generate such instructions.
80
83
for (int64_t i = 0 ; i < numElts; ++i) {
81
84
if (i != 0 ) {
82
- incIdx (srcIdx, sourceVectorType, /* step=*/ 1 );
83
- incIdx (resIdx, resultVectorType, /* step=*/ extractSize);
85
+ incIdx (srcIdx, sourceVectorType. getShape () , /* step=*/ 1 );
86
+ incIdx (resIdx, resultVectorType. getShape () , /* step=*/ extractSize);
84
87
}
85
88
86
89
Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
131
134
Value result = rewriter.create <ub::PoisonOp>(loc, resultVectorType);
132
135
for (int64_t i = 0 ; i < numElts; ++i) {
133
136
if (i != 0 ) {
134
- incIdx (srcIdx, sourceVectorType, /* step=*/ extractSize);
135
- incIdx (resIdx, resultVectorType, /* step=*/ 1 );
137
+ incIdx (srcIdx, sourceVectorType. getShape () , /* step=*/ extractSize);
138
+ incIdx (resIdx, resultVectorType. getShape () , /* step=*/ 1 );
136
139
}
137
140
138
141
Value extract = rewriter.create <vector::ExtractStridedSliceOp>(
/// Generic vector.shape_cast lowering: replaces the cast with a chain of
/// vector.extract / vector.insert data-movement ops. Innermost dimensions
/// shared by source and result shapes are kept intact so whole slices (rather
/// than scalars) are moved per iteration.
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                              PatternRewriter &rewriter) const override {
  Location loc = op.getLoc();
  VectorType sourceType = op.getSourceVectorType();
  VectorType resultType = op.getResultVectorType();

  // Scalable vectors are not handled by this unrolling-based pattern.
  if (sourceType.isScalable() || resultType.isScalable())
    return failure();

  // Special case for n-D / 1-D lowerings with implementations that use
  // extract_strided_slice / insert_strided_slice.
  int64_t sourceRank = sourceType.getRank();
  int64_t resultRank = resultType.getRank();
  if ((sourceRank > 1 && resultRank == 1) ||
      (sourceRank == 1 && resultRank > 1))
    return failure();

  // Count the trailing dimensions that source and result have in common.
  // Each common inner dimension divides the number of extract/insert pairs
  // needed, since a whole common-shaped slice is moved at once.
  int64_t numExtracts = sourceType.getNumElements();
  int64_t nbCommonInnerDims = 0;
  while (true) {
    int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
    int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
    if (sourceDim < 0 || resultDim < 0)
      break;
    int64_t dimSize = sourceType.getDimSize(sourceDim);
    if (dimSize != resultType.getDimSize(resultDim))
      break;
    numExtracts /= dimSize;
    ++nbCommonInnerDims;
  }

  // Replace with data movement operations:
  //   x[0,0,0] = y[0,0]
  //   x[0,0,1] = y[0,1]
  //   x[0,1,0] = y[0,2]
  // etc., incrementing the two index vectors "row-major"
  // within the source and result shape.
  // The indices only cover the non-common leading dimensions; the common
  // inner dimensions are implicitly moved as one slice per extract/insert.
  SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
  SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
  // Start from a poison value; every element is overwritten by the inserts.
  Value result = rewriter.create<ub::PoisonOp>(loc, resultType);

  for (int64_t i = 0; i < numExtracts; i++) {
    if (i != 0) {
      // Advance both index vectors row-major over the leading (non-common)
      // dimensions only, hence the drop_back of the common inner dims.
      incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
      incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
    }

    Value extract =
        rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
    result =
        rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
  }
  rewriter.replaceOp(op, result);
  return success();
}
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
329
345
330
346
// 4. Increment the insert/extract indices, stepping by minExtractionSize
331
347
// for the trailing dimensions.
332
- incIdx (srcIdx, sourceVectorType, /* step=*/ minExtractionSize);
333
- incIdx (resIdx, resultVectorType, /* step=*/ minExtractionSize);
348
+ incIdx (srcIdx, sourceVectorType. getShape () , /* step=*/ minExtractionSize);
349
+ incIdx (resIdx, resultVectorType. getShape () , /* step=*/ minExtractionSize);
334
350
}
335
351
336
352
rewriter.replaceOp (op, result);
0 commit comments