Skip to content

[Transform][Vector] Lower operation to tile (virtual) vector #252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,29 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
let dependentDialects = ["scf::SCFDialect"];
}

def FoldTensorOperation : Pass<"fold-tensor-operation"> {
let summary = "Fold some tensor operation";
let description = [{
Remove some useless tensor operations.
}];
let dependentDialects = [
"::mlir::tensor::TensorDialect"
];
}

def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> {
let summary = "Lower tensor to tile (virtual) vector";
let description = [{
Lower operation operate on tensor to vector operation.
}];
let dependentDialects = [
"::mlir::func::FuncDialect",
"::mlir::math::MathDialect",
"::mlir::arith::ArithDialect",
"::mlir::tensor::TensorDialect",
"::mlir::linalg::LinalgDialect",
"::mlir::vector::VectorDialect",
];
}

#endif // GC_DIALECT_GC_PASSES
2 changes: 2 additions & 0 deletions lib/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ gc_add_mlir_library(GcPasses
TilingUtil.cpp
SinkOpIntoInnerLoop.cpp
MergeNestedForall.cpp
FoldTensorOperation.cpp
LowerToTileVector.cpp

DEPENDS
GraphCompilerPassIncGen
Expand Down
53 changes: 53 additions & 0 deletions lib/gc/Transforms/FoldTensorOperation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===-- FoldTensorOperation.cpp - fold tensor op ----------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "gc/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h"

namespace mlir {
namespace gc {

#define GEN_PASS_DEF_FOLDTENSOROPERATION
#include "gc/Transforms/Passes.h.inc"
namespace {

/// FoldTensorOperationPass is a pass that fold some useless tensor
/// operation.
struct FoldTensorOperationPass
: public impl::FoldTensorOperationBase<FoldTensorOperationPass> {
void runOnOperation() final {
//
auto *ctx = &getContext();
RewritePatternSet pattern(ctx);

tensor::ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
return producer && producer->hasOneUse();
};
// Some operation convert as constant, this pattern can help us to improve
// the performance.
tensor::populateRewriteAsConstantPatterns(pattern, defaultControlFn);
// Remove unnessary operation like extract slice and insert slice
tensor::populateReassociativeReshapeFoldingPatterns(pattern);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(pattern);
tensor::populateFoldTensorSubsetOpPatterns(pattern);

GreedyRewriteConfig config;
// Use to remove useless tensor operation like extract or
// insert slice.
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern),
config);
}
};
} // namespace
} // namespace gc
} // namespace mlir
Loading