Skip to content

Commit 238706b

Browse files
authored
Generalize the tensor.gather to linalg.generic #412
The tensor.gather is not bufferized and here add a pass to transform it to linalg.generic so that it could be further lowered correctly.(Also add tensor::populateDecomposeTensorConcatPatterns to handle tensor.concat)
1 parent 3c716fe commit 238706b

File tree

5 files changed

+266
-1
lines changed

5 files changed

+266
-1
lines changed

include/gc/Transforms/Passes.td

+11-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def MergeAlloc : Pass<"gc-merge-alloc", "func::FuncOp"> {
3030
lifetime of the original memref before merging. This pass schedules the
3131
offsets to 1) make sure the offsets and address ranges do not overlap if
3232
two "mergeable" allocations have overlapped lifetime, and 2) reuse the
33-
address ranges that are considered "hot" in cache for an later allocation.
33+
address ranges that are considered "hot" in cache for an later allocation.
3434
}];
3535
let options = [
3636
Option<"optionAnalysisOnly", "analysis-only", "bool",
@@ -231,6 +231,16 @@ def FoldTensorOperation : Pass<"fold-tensor-operation"> {
231231
let description = [{
232232
Remove some useless tensor operations.
233233
}];
234+
let dependentDialects = [
235+
"tensor::TensorDialect",
236+
];
237+
}
238+
239+
def DecomposeTensorOperation : Pass<"decompose-tensor-operation"> {
240+
let summary = "decompose some tensor operation";
241+
let description = [{
242+
Decompose some tensor operations(concat, gather) into linalg operation.
243+
}];
234244
let dependentDialects = [
235245
"::mlir::tensor::TensorDialect"
236246
];

lib/gc/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ gc_add_mlir_library(GcPasses
2525
MergeAlloc.cpp
2626
MergeAllocTickBased.cpp
2727
FoldTensorOperation.cpp
28+
DecomposeTensorOperation.cpp
2829
LowerToTileVector.cpp
2930

3031
DEPENDS
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//===-- DecomposeTensorOperation.cpp - DESC ---------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "gc/Transforms/Passes.h"
9+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
10+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
11+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14+
#include "llvm/Support/Casting.h"
15+
16+
namespace mlir {
17+
namespace gc {
18+
19+
#define GEN_PASS_DEF_DECOMPOSETENSOROPERATION
20+
#include "gc/Transforms/Passes.h.inc"
21+
namespace {
22+
23+
/// Decompose `tensor.gather` into `linalg.generic`.
24+
///
25+
/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
26+
/// tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
27+
///
28+
/// Becomes
29+
///
30+
/// %empty = tensor.empty() : tensor<1x7x128xf16>
31+
/// %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1,
32+
/// 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
33+
/// ["parallel", "parallel", "parallel"]} ins(%expanded : tensor<1x7x1xindex>)
34+
/// outs(%13 : tensor<1x7x128xf16>) {
35+
/// ^bb0(%in: index, %out: f16):
36+
/// %17 = linalg.index 2 : index
37+
/// %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
38+
/// linalg.yield %extracted : f16
39+
/// } -> tensor<1x7x128xf16>
40+
struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
41+
using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
42+
43+
SmallVector<OpFoldResult> getDstMixedSizes(PatternRewriter &rewriter,
44+
Location loc,
45+
tensor::GatherOp gatherOp) const {
46+
SmallVector<OpFoldResult> dstSize =
47+
tensor::getMixedSizes(rewriter, loc, gatherOp.getResult());
48+
SmallVector<OpFoldResult> indexSize =
49+
tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices());
50+
SmallVector<OpFoldResult> srcSize =
51+
tensor::getMixedSizes(rewriter, loc, gatherOp.getSource());
52+
SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
53+
bool isShrinkDst = (indexSize.size() - 1) + srcSize.size() ==
54+
dstSize.size() + gatherDims.size();
55+
for (size_t i = 0; i < indexSize.size() - 1; i++) {
56+
dstSize[i] = indexSize[i];
57+
}
58+
auto cnt = 0;
59+
for (size_t i = indexSize.size() - 1; i < dstSize.size(); i++) {
60+
while (isShrinkDst && llvm::find(gatherDims, cnt) != gatherDims.end()) {
61+
cnt++;
62+
}
63+
dstSize[i] = llvm::find(gatherDims, cnt) == gatherDims.end()
64+
? srcSize[cnt]
65+
: getAsIndexOpFoldResult(rewriter.getContext(), 1);
66+
cnt++;
67+
}
68+
return dstSize;
69+
}
70+
71+
LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
72+
PatternRewriter &rewriter) const override {
73+
OpBuilder::InsertionGuard g(rewriter);
74+
rewriter.setInsertionPoint(gatherOp);
75+
Location loc = gatherOp.getLoc();
76+
SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
77+
78+
// create destination tensor for linalg out
79+
RankedTensorType dstType = gatherOp.getResultType();
80+
Value dstTensor = rewriter.create<tensor::EmptyOp>(
81+
loc, getDstMixedSizes(rewriter, loc, gatherOp),
82+
dstType.getElementType());
83+
84+
// split index tensor to create the linalg input
85+
SmallVector<Value> indexTensors;
86+
Value originIndexTensor = gatherOp.getIndices();
87+
SmallVector<OpFoldResult> indexTensorSize =
88+
tensor::getMixedSizes(rewriter, loc, originIndexTensor);
89+
SmallVector<OpFoldResult> indexTensorStride(
90+
indexTensorSize.size(),
91+
getAsIndexOpFoldResult(rewriter.getContext(), 1));
92+
SmallVector<OpFoldResult> indexTensorOffset(
93+
indexTensorSize.size(),
94+
getAsIndexOpFoldResult(rewriter.getContext(), 0));
95+
indexTensorSize[indexTensorSize.size() - 1] =
96+
getAsIndexOpFoldResult(rewriter.getContext(), 1);
97+
98+
for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
99+
indexTensorOffset[indexTensorSize.size() - 1] =
100+
getAsIndexOpFoldResult(rewriter.getContext(), cnt);
101+
Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
102+
loc, originIndexTensor, indexTensorOffset, indexTensorSize,
103+
indexTensorStride);
104+
indexTensors.emplace_back(indexTensor);
105+
}
106+
107+
// create the affine map
108+
SmallVector<AffineMap> affineMaps;
109+
SmallVector<AffineExpr> dimExprs;
110+
size_t dstRank = dstType.getShape().size();
111+
for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i)
112+
dimExprs.push_back(rewriter.getAffineDimExpr(i));
113+
dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
114+
115+
for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
116+
AffineMap currentMap =
117+
AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs,
118+
rewriter.getContext());
119+
affineMaps.emplace_back(currentMap);
120+
}
121+
affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));
122+
123+
// create iterater types array
124+
SmallVector<utils::IteratorType> iteratorTypesArray(
125+
dstRank, utils::IteratorType::parallel);
126+
127+
// check whether the gather op is valid
128+
size_t srcRank = gatherOp.getSourceType().getShape().size();
129+
assert(((indexTensorSize.size() - 1) + srcRank == dstRank ||
130+
(indexTensorSize.size() - 1) + srcRank ==
131+
dstRank + gatherDims.size()) &&
132+
"Expected: index_size - 1 + source_size == dst_size or dst_szie - "
133+
"gather_size. \n");
134+
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
135+
gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
136+
affineMaps, iteratorTypesArray,
137+
[&](OpBuilder &b, Location loc, ValueRange args) {
138+
SmallVector<Value> indexValues(srcRank);
139+
bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank ==
140+
dstRank + gatherDims.size();
141+
int cnt = 0;
142+
for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
143+
while (isShrinkDst &&
144+
llvm::find(gatherDims, cnt) != gatherDims.end()) {
145+
cnt++;
146+
}
147+
indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);
148+
cnt++;
149+
}
150+
for (auto &&[i, dim] : llvm::enumerate(gatherDims)) {
151+
indexValues[dim] = args[i];
152+
}
153+
154+
Value extract = b.create<tensor::ExtractOp>(loc, gatherOp.getSource(),
155+
indexValues);
156+
b.create<linalg::YieldOp>(loc, extract);
157+
});
158+
return success();
159+
}
160+
};
161+
162+
/// DecomposeTensorOperationPass is a pass that decompose some tensor
163+
/// operations like tensor.gather, tensor.concat.
164+
struct DecomposeTensorOperationPass
165+
: public impl::DecomposeTensorOperationBase<DecomposeTensorOperationPass> {
166+
void runOnOperation() final {
167+
auto *ctx = &getContext();
168+
RewritePatternSet patterns(ctx);
169+
170+
patterns.add<DecomposeGatherOp>(patterns.getContext());
171+
tensor::populateDecomposeTensorConcatPatterns(patterns);
172+
173+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
174+
std::move(patterns)))) {
175+
return signalPassFailure();
176+
}
177+
}
178+
};
179+
} // namespace
180+
} // namespace gc
181+
} // namespace mlir

lib/gc/Transforms/GPU/Pipeline.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ void populateGPUPipeline(OpPassManager &pm,
3535
pm.addNestedPass<func::FuncOp>(createAddContextArg());
3636
}
3737

38+
pm.addPass(createDecomposeTensorOperation());
3839
pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
3940

4041
pm.addPass(bufferization::createEmptyTensorEliminationPass());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: gc-opt %s -decompose-tensor-operation --split-input-file | FileCheck %s
2+
3+
/// CHECK-LABEL: @gather_single_gather_dim
4+
func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
5+
/// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
6+
/// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2x2xf32>)
7+
%1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
8+
return %1 : tensor<2x3x2x2x2xf32>
9+
}
10+
11+
// -----
12+
13+
/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
14+
func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
15+
/// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
16+
/// CHECK: linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x1x2x2xf32>)
17+
%1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
18+
return %1 : tensor<2x3x2x1x2x2xf32>
19+
}
20+
21+
// -----
22+
23+
/// CHECK-LABEL: @gather_multiple_gather_dim
24+
func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
25+
/// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
26+
/// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
27+
/// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
28+
/// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2xf32>)
29+
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
30+
return %1 : tensor<2x3x2x2xf32>
31+
}
32+
33+
// -----
34+
35+
/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
36+
func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
37+
/// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x1x2xf32>
38+
/// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
39+
/// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
40+
/// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x1x1x2xf32>)
41+
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
42+
return %1 : tensor<2x3x2x1x1x2xf32>
43+
}
44+
45+
// -----
46+
47+
/// CHECK-LABEL: @gather_single_gather_dim_dynamic
48+
func.func @gather_single_gather_dim_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> {
49+
/// CHECK: %[[DIM1:.*]] = tensor.dim
50+
/// CHECK: %[[DIM2:.*]] = tensor.dim
51+
/// CHECK: %[[DIM3:.*]] = tensor.dim
52+
/// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]], %[[DIM3:.*]]) : tensor<2x3x?x?x?xf32>
53+
/// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x?x?x?xf32>)
54+
%1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<?x?x?x?xf32>, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32>
55+
return %1 : tensor<2x3x?x?x?xf32>
56+
}
57+
58+
// -----
59+
60+
/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic
61+
func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32> {
62+
/// CHECK: %[[DIM1:.*]] = tensor.dim
63+
/// CHECK: %[[DIM2:.*]] = tensor.dim
64+
/// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1:.*]], %[[DIM2:.*]]) : tensor<?x?x2x1x1x2xf32>
65+
/// CHECK: %[[DIM3:.*]] = tensor.dim
66+
/// CHECK: %[[DIM4:.*]] = tensor.dim
67+
/// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM3:.*]], %[[DIM4:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
68+
/// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [%[[DIM3:.*]], %[[DIM4:.*]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
69+
/// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<?x?x1xindex>, tensor<?x?x1xindex>) outs(%[[EMPTY:.*]] : tensor<?x?x2x1x1x2xf32>)
70+
%1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32>
71+
return %1 : tensor<?x?x2x1x1x2xf32>
72+
}

0 commit comments

Comments
 (0)