@@ -37,10 +37,37 @@ namespace {
37
37
///   %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
38
38
///   linalg.yield %extracted : f16
39
39
/// } -> tensor<1x7x128xf16>
40
-
41
40
struct DecomposeGatherOp : public OpRewritePattern <tensor::GatherOp> {
42
41
using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
43
42
43
+ SmallVector<OpFoldResult> getDstMixedSizes (PatternRewriter &rewriter,
44
+ Location loc,
45
+ tensor::GatherOp gatherOp) const {
46
+ SmallVector<OpFoldResult> dstSize =
47
+ tensor::getMixedSizes (rewriter, loc, gatherOp.getResult ());
48
+ SmallVector<OpFoldResult> indexSize =
49
+ tensor::getMixedSizes (rewriter, loc, gatherOp.getIndices ());
50
+ SmallVector<OpFoldResult> srcSize =
51
+ tensor::getMixedSizes (rewriter, loc, gatherOp.getSource ());
52
+ SmallVector<int64_t > gatherDims (gatherOp.getGatherDims ());
53
+ bool isShrinkDst = (indexSize.size () - 1 ) + srcSize.size () ==
54
+ dstSize.size () + gatherDims.size ();
55
+ for (size_t i = 0 ; i < indexSize.size () - 1 ; i++) {
56
+ dstSize[i] = indexSize[i];
57
+ }
58
+ auto cnt = 0 ;
59
+ for (size_t i = indexSize.size () - 1 ; i < dstSize.size (); i++) {
60
+ while (isShrinkDst && llvm::find (gatherDims, cnt) != gatherDims.end ()) {
61
+ cnt++;
62
+ }
63
+ dstSize[i] = llvm::find (gatherDims, cnt) == gatherDims.end ()
64
+ ? srcSize[cnt]
65
+ : getAsIndexOpFoldResult (rewriter.getContext (), 1 );
66
+ cnt++;
67
+ }
68
+ return dstSize;
69
+ }
70
+
44
71
LogicalResult matchAndRewrite (tensor::GatherOp gatherOp,
45
72
PatternRewriter &rewriter) const override {
46
73
OpBuilder::InsertionGuard g (rewriter);
@@ -51,7 +78,7 @@ struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
51
78
// create destination tensor for linalg out
52
79
RankedTensorType dstType = gatherOp.getResultType ();
53
80
Value dstTensor = rewriter.create <tensor::EmptyOp>(
54
- loc, tensor::getMixedSizes (rewriter, loc, gatherOp. getResult () ),
81
+ loc, getDstMixedSizes (rewriter, loc, gatherOp),
55
82
dstType.getElementType ());
56
83
57
84
// split index tensor to create the linalg input
@@ -113,8 +140,8 @@ struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
113
140
dstRank + gatherDims.size ();
114
141
int cnt = 0 ;
115
142
for (auto i = indexTensorSize.size () - 1 ; i < dstRank; i++) {
116
- while (llvm::find (gatherDims, cnt) != gatherDims. end () &&
117
- isShrinkDst ) {
143
+ while (isShrinkDst &&
144
+ llvm::find (gatherDims, cnt) != gatherDims. end () ) {
118
145
cnt++;
119
146
}
120
147
indexValues[cnt] = b.create <linalg::IndexOp>(loc, i);
0 commit comments