Skip to content

Commit 42a1a9f

Browse files
zhczhongkurapov-peter
authored andcommitted
support dyanmic case
1 parent 7f8483b commit 42a1a9f

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

lib/gc/Transforms/DecomposeTensorOperation.cpp

+31-4
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,37 @@ namespace {
3737
/// %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
3838
/// linalg.yield %extracted : f16
3939
/// } -> tensor<1x7x128xf16>
40-
4140
struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
4241
using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
4342

43+
SmallVector<OpFoldResult> getDstMixedSizes(PatternRewriter &rewriter,
44+
Location loc,
45+
tensor::GatherOp gatherOp) const {
46+
SmallVector<OpFoldResult> dstSize =
47+
tensor::getMixedSizes(rewriter, loc, gatherOp.getResult());
48+
SmallVector<OpFoldResult> indexSize =
49+
tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices());
50+
SmallVector<OpFoldResult> srcSize =
51+
tensor::getMixedSizes(rewriter, loc, gatherOp.getSource());
52+
SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
53+
bool isShrinkDst = (indexSize.size() - 1) + srcSize.size() ==
54+
dstSize.size() + gatherDims.size();
55+
for (size_t i = 0; i < indexSize.size() - 1; i++) {
56+
dstSize[i] = indexSize[i];
57+
}
58+
auto cnt = 0;
59+
for (size_t i = indexSize.size() - 1; i < dstSize.size(); i++) {
60+
while (isShrinkDst && llvm::find(gatherDims, cnt) != gatherDims.end()) {
61+
cnt++;
62+
}
63+
dstSize[i] = llvm::find(gatherDims, cnt) == gatherDims.end()
64+
? srcSize[cnt]
65+
: getAsIndexOpFoldResult(rewriter.getContext(), 1);
66+
cnt++;
67+
}
68+
return dstSize;
69+
}
70+
4471
LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
4572
PatternRewriter &rewriter) const override {
4673
OpBuilder::InsertionGuard g(rewriter);
@@ -51,7 +78,7 @@ struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
5178
// create destination tensor for linalg out
5279
RankedTensorType dstType = gatherOp.getResultType();
5380
Value dstTensor = rewriter.create<tensor::EmptyOp>(
54-
loc, tensor::getMixedSizes(rewriter, loc, gatherOp.getResult()),
81+
loc, getDstMixedSizes(rewriter, loc, gatherOp),
5582
dstType.getElementType());
5683

5784
// split index tensor to create the linalg input
@@ -113,8 +140,8 @@ struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
113140
dstRank + gatherDims.size();
114141
int cnt = 0;
115142
for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
116-
while (llvm::find(gatherDims, cnt) != gatherDims.end() &&
117-
isShrinkDst) {
143+
while (isShrinkDst &&
144+
llvm::find(gatherDims, cnt) != gatherDims.end()) {
118145
cnt++;
119146
}
120147
indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);

test/mlir/test/gc/Transforms/DecomposeTensorOperation.mlir

+29
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,32 @@ func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg
4141
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
4242
return %1 : tensor<2x3x2x1x1x2xf32>
4343
}
44+
45+
// -----
46+
47+
/// CHECK-LABEL: @gather_single_gather_dim_dynamic
48+
func.func @gather_single_gather_dim_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> {
49+
/// CHECK: %[[DIM1:.*]] = tensor.dim
50+
/// CHECK: %[[DIM2:.*]] = tensor.dim
51+
/// CHECK: %[[DIM3:.*]] = tensor.dim
52+
/// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]], %[[DIM3:.*]]) : tensor<2x3x?x?x?xf32>
53+
/// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x?x?x?xf32>)
54+
%1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<?x?x?x?xf32>, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32>
55+
return %1 : tensor<2x3x?x?x?xf32>
56+
}
57+
58+
// -----
59+
60+
/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic
61+
func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32> {
62+
/// CHECK: %[[DIM1:.*]] = tensor.dim
63+
/// CHECK: %[[DIM2:.*]] = tensor.dim
64+
/// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]]) : tensor<?x?x2x1x1x2xf32>
65+
/// CHECK: %[[DIM3:.*]] = tensor.dim
66+
/// CHECK: %[[DIM4:.*]] = tensor.dim
67+
/// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM3:.*]], %[[DIM4:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
68+
/// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [%[[DIM3:.*]], %[[DIM4:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
69+
/// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<?x?x1xindex>, tensor<?x?x1xindex>) outs(%[[EMPTY:.*]] : tensor<?x?x2x1x1x2xf32>)
70+
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32>
71+
return %1 : tensor<?x?x2x1x1x2xf32>
72+
}

0 commit comments

Comments
 (0)