Skip to content

Commit 61e90af

Browse files
authored
[Transform][Vector] Lower operation to tile (virtual) vector (#252)
1 parent 23849b3 commit 61e90af

File tree

6 files changed

+971
-0
lines changed

6 files changed

+971
-0
lines changed

include/gc/Transforms/Passes.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,29 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
117117
let dependentDialects = ["scf::SCFDialect"];
118118
}
119119

120+
def FoldTensorOperation : Pass<"fold-tensor-operation"> {
121+
let summary = "Fold some tensor operation";
122+
let description = [{
123+
Remove some useless tensor operations.
124+
}];
125+
let dependentDialects = [
126+
"::mlir::tensor::TensorDialect"
127+
];
128+
}
129+
130+
def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> {
131+
let summary = "Lower tensor to tile (virtual) vector";
132+
let description = [{
133+
Lower operation operate on tensor to vector operation.
134+
}];
135+
let dependentDialects = [
136+
"::mlir::func::FuncDialect",
137+
"::mlir::math::MathDialect",
138+
"::mlir::arith::ArithDialect",
139+
"::mlir::tensor::TensorDialect",
140+
"::mlir::linalg::LinalgDialect",
141+
"::mlir::vector::VectorDialect",
142+
];
143+
}
144+
120145
#endif // GC_DIALECT_GC_PASSES

lib/gc/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ gc_add_mlir_library(GcPasses
2020
TilingUtil.cpp
2121
SinkOpIntoInnerLoop.cpp
2222
MergeNestedForall.cpp
23+
FoldTensorOperation.cpp
24+
LowerToTileVector.cpp
2325

2426
DEPENDS
2527
GraphCompilerPassIncGen
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===-- FoldTensorOperation.cpp - fold tensor op ----------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "gc/Transforms/Passes.h"
9+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
10+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11+
#include "mlir/IR/PatternMatch.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
#include "llvm/Support/Casting.h"
14+
15+
namespace mlir {
16+
namespace gc {
17+
18+
#define GEN_PASS_DEF_FOLDTENSOROPERATION
19+
#include "gc/Transforms/Passes.h.inc"
20+
namespace {
21+
22+
/// FoldTensorOperationPass is a pass that fold some useless tensor
23+
/// operation.
24+
struct FoldTensorOperationPass
25+
: public impl::FoldTensorOperationBase<FoldTensorOperationPass> {
26+
void runOnOperation() final {
27+
//
28+
auto *ctx = &getContext();
29+
RewritePatternSet pattern(ctx);
30+
31+
tensor::ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
32+
Operation *producer = fusedOperand->get().getDefiningOp();
33+
return producer && producer->hasOneUse();
34+
};
35+
// Some operation convert as constant, this pattern can help us to improve
36+
// the performance.
37+
tensor::populateRewriteAsConstantPatterns(pattern, defaultControlFn);
38+
// Remove unnessary operation like extract slice and insert slice
39+
tensor::populateReassociativeReshapeFoldingPatterns(pattern);
40+
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(pattern);
41+
tensor::populateFoldTensorSubsetOpPatterns(pattern);
42+
43+
GreedyRewriteConfig config;
44+
// Use to remove useless tensor operation like extract or
45+
// insert slice.
46+
config.strictMode = GreedyRewriteStrictness::ExistingOps;
47+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern),
48+
config);
49+
}
50+
};
51+
} // namespace
52+
} // namespace gc
53+
} // namespace mlir

0 commit comments

Comments
 (0)