1
+ // ===-- FoldTensorOperation.cpp - fold tensor op ----------------*- C++ -*-===//
2
+ //
3
+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+ #include " gc/Transforms/Passes.h"
9
+ #include " mlir/Dialect/Tensor/IR/Tensor.h"
10
+ #include " mlir/Dialect/Tensor/Transforms/Transforms.h"
11
+ #include " mlir/IR/PatternMatch.h"
12
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
+ #include " llvm/Support/Casting.h"
14
+
15
+ namespace mlir {
16
+ namespace gc {
17
+
18
+ #define GEN_PASS_DEF_FOLDTENSOROPERATION
19
+ #include " gc/Transforms/Passes.h.inc"
20
+ namespace {
21
+
22
+ // / FoldTensorOperationPass is a pass that fold some useless tensor
23
+ // / operation.
24
+ struct FoldTensorOperationPass
25
+ : public impl::FoldTensorOperationBase<FoldTensorOperationPass> {
26
+ void runOnOperation () final {
27
+ //
28
+ auto *ctx = &getContext ();
29
+ RewritePatternSet pattern (ctx);
30
+
31
+ tensor::ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
32
+ Operation *producer = fusedOperand->get ().getDefiningOp ();
33
+ return producer && producer->hasOneUse ();
34
+ };
35
+ // Some operation convert as constant, this pattern can help us to improve
36
+ // the performance.
37
+ tensor::populateRewriteAsConstantPatterns (pattern, defaultControlFn);
38
+ // Remove unnessary operation like extract slice and insert slice
39
+ tensor::populateReassociativeReshapeFoldingPatterns (pattern);
40
+ tensor::populateMergeConsecutiveInsertExtractSlicePatterns (pattern);
41
+ tensor::populateFoldTensorSubsetOpPatterns (pattern);
42
+
43
+ GreedyRewriteConfig config;
44
+ // Use to remove useless tensor operation like extract or
45
+ // insert slice.
46
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
47
+ (void )applyPatternsAndFoldGreedily (getOperation (), std::move (pattern),
48
+ config);
49
+ }
50
+ };
51
+ } // namespace
52
+ } // namespace gc
53
+ } // namespace mlir
0 commit comments