Skip to content

Commit 825179b

Browse files
committed
refactor the pass tutor opt
1 parent e31965d commit 825179b

File tree

14 files changed

+417
-20
lines changed

14 files changed

+417
-20
lines changed

mlir/mhlo-pass-example/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ add_subdirectory(Outline)
4949
add_subdirectory(Pow2)
5050
add_subdirectory(Tanh)
5151
add_subdirectory(tests)
52-
add_subdirectory(pass-tutor-opt)
52+
add_subdirectory(PassTutor)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
include_directories(include)
2+
add_subdirectory(include)
3+
4+
set(LLVM_LINK_COMPONENTS Support)
5+
6+
set(LLVM_TARGET_DEFINITIONS mlir/Pow2.td)
7+
mlir_tablegen(Pow2.inc -gen-rewriters EXTRA_INCLUDES ${MLIR_INCLUDE_DIR})
8+
add_public_tablegen_target(PassTutorIncGen)
9+
10+
set(LLVM_TARGET_DEFINITIONS mlir/transform/pass.td)
11+
mlir_tablegen(Pow2Pass.inc -gen-pass-decls EXTRA_INCLUDES ${MLIR_INCLUDE_DIR})
12+
mlir_tablegen(Pow2Pass.md -gen-pass-doc)
13+
add_public_tablegen_target(PassTutorPassIncGen)
14+
15+
add_mlir_pdll_library(PassTutorPdllIncGen mlir/Pow2.pdll Pow2Pdll.inc
16+
EXTRA_INCLUDES ${MLIR_INCLUDE_DIR})
17+
18+
add_executable(pass-tutor-opt main.cpp mlir/PassTutor.cpp)
19+
20+
add_dependencies(pass-tutor-opt PassTutorIncGen PassTutorPdllIncGen
21+
PassTutorPassIncGen)
22+
23+
include_directories(${CMAKE_CURRENT_BINARY_DIR})
24+
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
25+
26+
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
27+
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
28+
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
29+
30+
target_link_libraries(
31+
pass-tutor-opt PRIVATE ${dialect_libs} ${conversion_libs} ${extension_libs}
32+
MLIRIR MLIRMlirOptMain MhloDialect)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(passes)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# empty
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef POW2_H
2+
#define POW2_H
3+
4+
#include <memory>
5+
6+
namespace mlir {
7+
8+
class Pass;
9+
#define GEN_PASS_DECL_POW2PASS
10+
#include "Pow2Pass.inc"
11+
12+
namespace mhlo {
13+
14+
std::unique_ptr<Pass> createSubstitutePow2Pass();
15+
std::unique_ptr<Pass> createSubstitutePow2Pass(const Pow2PassOptions &options);
16+
std::unique_ptr<mlir::Pass> createStaticOpCounter();
17+
18+
#define GEN_PASS_REGISTRATION
19+
#include "Pow2Pass.inc"
20+
21+
} // namespace mhlo
22+
} // namespace mlir
23+
24+
#endif // POW2_H
File renamed without changes.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module {
2+
func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
3+
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
4+
%1 = mhlo.constant dense<2.0> : tensor<2x2xf32>
5+
%2 = "mhlo.power"(%0, %1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
6+
func.return %2 : tensor<2x2xf32>
7+
}
8+
}
9+
10+
// after conversion:
11+
// module {
12+
// func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
13+
// %0 = "mhlo.add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
14+
// %2 = "mhlo.multiply"(%0, %0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
15+
// func.return %2 : tensor<2x2xf32>
16+
// }
17+
// }
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
#include <mhlo/IR/hlo_ops.h>
2+
#include <mlir/Dialect/Func/IR/FuncOps.h>
3+
#include <mlir/IR/BuiltinDialect.h>
4+
#include <mlir/IR/Matchers.h>
5+
#include <mlir/IR/PatternMatch.h>
6+
#include <mlir/Pass/Pass.h>
7+
#include <mlir/Transforms/DialectConversion.h>
8+
9+
#include <llvm/Support/Format.h>
10+
#include <llvm/Support/raw_ostream.h>
11+
#include <mlir/Analysis/CallGraph.h>
12+
#include <mlir/IR/Action.h>
13+
#include <mlir/IR/Builders.h>
14+
#include <mlir/IR/MLIRContext.h>
15+
#include <mlir/IR/PatternMatch.h>
16+
#include <mlir/Parser/Parser.h>
17+
#include <mlir/Pass/PassManager.h>
18+
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
19+
#include <mlir/Transforms/Passes.h>
20+
21+
#include "passes/Pow2.h"
22+
#include "llvm/Support/Casting.h"
23+
#include "llvm/Support/Debug.h"
24+
25+
#include <memory>
26+
#include <numeric>
27+
28+
using namespace mlir;
29+
30+
namespace {
31+
// We add MLIR actions here as an example.
32+
/// A custom Action can be defined minimally by deriving from
33+
/// `tracing::ActionImpl`. The action is same as the pass declaration with tddr
34+
/// rules. only for `xxx-opt` binary. and run with `--log-actions-to=-` to dump
35+
/// the actions.
36+
class EchoAction : public tracing::ActionImpl<EchoAction> {
37+
public:
38+
using Base = tracing::ActionImpl<EchoAction>;
39+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EchoAction)
40+
41+
/// Actions are initialized with an array of IRUnit (that is either Operation,
42+
/// Block, or Region) that provide context for the IR affected by a
43+
/// transformation.
44+
EchoAction(ArrayRef<IRUnit> irUnits, int iteration)
45+
: Base(irUnits), iteration(iteration) {}
46+
/// This tag should uniquely identify this action, it can be matched for
47+
/// filtering during processing.
48+
static constexpr StringLiteral tag = "echo-action";
49+
static constexpr StringLiteral desc = "Just echo the iteration";
50+
51+
void print(raw_ostream &os) const override {
52+
os << "EchoAction: " << iteration << "\n";
53+
}
54+
55+
private:
56+
int iteration;
57+
};
58+
} // namespace
59+
60+
namespace {
61+
bool ValueEql2(Value operand) {
62+
FloatAttr::ValueType FValue = FloatAttr::ValueType(2.0);
63+
if (matchPattern(operand, m_ConstantFloat(&FValue))) {
64+
if (FValue.convertToFloat() == 2.0) {
65+
return true;
66+
}
67+
}
68+
return false;
69+
}
70+
71+
static LogicalResult Eqn2Impl(PatternRewriter &rewriter, Value value) {
72+
return success(ValueEql2(value));
73+
}
74+
75+
} // namespace
76+
77+
void registerNativeConstraints(RewritePatternSet &patterns) {
78+
patterns.getPDLPatterns().registerConstraintFunction("Eqn2", Eqn2Impl);
79+
}
80+
81+
namespace {
82+
/// Include the patterns defined in the Declarative Rewrite framework.
83+
#include "Pow2.inc"
84+
#include "Pow2Pdll.inc"
85+
} // namespace
86+
87+
namespace {
88+
struct SubstitutePow2Pass
89+
: public PassWrapper<SubstitutePow2Pass, OperationPass<func::FuncOp>> {
90+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SubstitutePow2Pass)
91+
92+
void runOnOperation() final;
93+
};
94+
} // namespace
95+
96+
void SubstitutePow2Pass::runOnOperation() {
97+
auto op = getOperation();
98+
RewritePatternSet patterns(&getContext());
99+
patterns.add<Pow2OptPattern>(&getContext());
100+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
101+
signalPassFailure();
102+
}
103+
104+
namespace {
105+
struct SubstitutePow2PdllPass
106+
: public PassWrapper<SubstitutePow2PdllPass, OperationPass<func::FuncOp>> {
107+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SubstitutePow2PdllPass)
108+
109+
void runOnOperation() final;
110+
};
111+
} // namespace
112+
113+
void SubstitutePow2PdllPass::runOnOperation() {
114+
auto op = getOperation();
115+
RewritePatternSet patterns(&getContext());
116+
// --- insert the native constraints ---
117+
registerNativeConstraints(patterns);
118+
// --- insert the native constraints ---
119+
patterns.add<Pow2PdllOptPattern>(&getContext());
120+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
121+
signalPassFailure();
122+
}
123+
124+
namespace {
125+
#define GEN_PASS_DEF_POW2PASS
126+
#include "Pow2Pass.inc"
127+
} // namespace
128+
129+
namespace {
130+
struct SubstitutePow2PdllGenPass
131+
: impl::Pow2PassBase<SubstitutePow2PdllGenPass> {
132+
using Pow2PassBase::Pow2PassBase;
133+
134+
void runOnOperation() final {
135+
auto op = getOperation();
136+
MLIRContext *context = &getContext();
137+
RewritePatternSet patterns(&*context);
138+
// --- insert the native constraints ---
139+
registerNativeConstraints(patterns);
140+
// --- insert the native constraints ---
141+
patterns.add<Pow2PdllOptPattern>(&*context);
142+
// Here, we wrap the applyPatternsAndFoldGreedily in a lambda function and
143+
// pass it to the MLIR Action.
144+
Operation *opp = getOperation();
145+
ArrayRef<IRUnit> irUnits{opp};
146+
context->executeAction<EchoAction>(
147+
[&]() {
148+
// Here is the pass body.
149+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
150+
signalPassFailure();
151+
},
152+
/*irUnits=*/irUnits, /*iteration=*/10);
153+
// Above, we pass the irUnits and iteration to the EchoAction.
154+
statistic++;
155+
};
156+
};
157+
} // namespace
158+
159+
std::unique_ptr<mlir::Pass> mhlo::createSubstitutePow2Pass() {
160+
// There are 2 methods to achieve the same goal:
161+
// 1. use the tddr rules to rewrite the IR
162+
// return std::make_unique<SubstitutePow2Pass>();
163+
// 2. use the pdll to rewrite the IR
164+
// return std::make_unique<SubstitutePow2PdllPass>();
165+
// 3. use tddr to generate pass declaration.
166+
return std::make_unique<SubstitutePow2PdllGenPass>();
167+
}
168+
169+
/// An interesting analysis.
170+
struct StaticOpCounterAnalysis {
171+
llvm::StringMap<int> opCount;
172+
// Compute this analysis with the provided operation.
173+
StaticOpCounterAnalysis(Operation *op) : opCount({}){};
174+
175+
void add(Operation *op) {
176+
auto opName = op->getName().getStringRef();
177+
opCount.find(opName) == opCount.end()
178+
? opCount[opName] = 1
179+
: opCount[opName] = opCount[opName] + 1;
180+
}
181+
182+
llvm::StringMap<int> getOpCount() const { return opCount; };
183+
};
184+
185+
struct StaticOpCounterAnalysisWithDependency {
186+
StaticOpCounterAnalysisWithDependency(Operation *op, AnalysisManager &am) {
187+
// Request other analysis as dependency
188+
StaticOpCounterAnalysis &otherAnalysis =
189+
am.getAnalysis<StaticOpCounterAnalysis>();
190+
}
191+
192+
bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
193+
// Check if analysis or its dependency were invalidated
194+
return !pa.isPreserved<StaticOpCounterAnalysisWithDependency>() ||
195+
!pa.isPreserved<StaticOpCounterAnalysis>();
196+
}
197+
};
198+
199+
namespace {
200+
struct StaticOpCounter
201+
: public PassWrapper<StaticOpCounter, OperationPass<func::FuncOp>> {
202+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StaticOpCounter)
203+
204+
void runOnOperation() final;
205+
};
206+
} // namespace
207+
208+
void StaticOpCounter::runOnOperation() {
209+
StaticOpCounterAnalysis &myAnalysis = getAnalysis<StaticOpCounterAnalysis>();
210+
auto module_op = getOperation();
211+
for (auto &op : module_op.getOps()) {
212+
myAnalysis.add(&op);
213+
}
214+
215+
const char *str1 = "NAME";
216+
const char *str2 = "#N DIRECT CALLS";
217+
218+
llvm::dbgs() << "================================================="
219+
<< "\n";
220+
llvm::dbgs() << "MLIR-PASS-TUTOR: static analysis results\n";
221+
llvm::dbgs() << "=================================================\n";
222+
llvm::dbgs() << llvm::format("%-20s %-10s\n", str1, str2);
223+
llvm::dbgs() << "-------------------------------------------------"
224+
<< "\n";
225+
for (auto &CallCount : myAnalysis.getOpCount()) {
226+
llvm::dbgs() << llvm::format("%-20s %-10lu\n",
227+
CallCount.first().str().c_str(),
228+
CallCount.getValue());
229+
}
230+
231+
llvm::dbgs() << "-------------------------------------------------"
232+
<< "\n\n";
233+
}
234+
235+
std::unique_ptr<mlir::Pass> mhlo::createStaticOpCounter() {
236+
// There are 2 methods to achieve the same goal:
237+
// 1. use the tddr rules to rewrite the IR
238+
// return std::make_unique<SubstitutePow2Pass>();
239+
// 2. use the pdll to rewrite the IR
240+
// return std::make_unique<SubstitutePow2PdllPass>();
241+
// 3. use tddr to generate pass declaration.
242+
return std::make_unique<StaticOpCounter>();
243+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include "mlir/IR/PatternBase.td"
2+
#include "mhlo/IR/hlo_ops.td"
3+
4+
// The first way to replace
5+
// Pattern Pow2PdllOptPattern with benefit(0) {
6+
// // ** match section ** //
7+
// let root = op<mhlo.power>( arg :Value,
8+
// op<mhlo.constant>() {
9+
// value = attr<"dense<2.0> : tensor<2x2xf32>">});
10+
11+
// // ** rewrite section ** //
12+
// replace root with op<mhlo.multiply>(arg, arg);
13+
// }
14+
15+
// Add Custom Constraints to rewrite pattern
16+
Constraint Eqn2(value: Value);
17+
18+
Constraint TypesAreIdentical(value1: Value, value2: Value)[{
19+
return success(value1.getType() == value2.getType());
20+
}];
21+
22+
Pattern Pow2PdllOptPattern with benefit(0) {
23+
// ** match section ** //
24+
let const_2 : Value = op<mhlo.constant>();
25+
let arg : Value;
26+
TypesAreIdentical(arg, const_2);
27+
Eqn2(const_2);
28+
let root = op<mhlo.power>(arg, const_2);
29+
replace root with op<mhlo.multiply>(arg, arg);
30+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines language-specific pattern match optimizations for Toy using
10+
// Declarative Rewrite Rules (DRR) specified using TableGen records.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef POW_2
15+
#define POW_2
16+
17+
include "mlir/IR/PatternBase.td"
18+
include "mhlo/IR/hlo_ops.td"
19+
20+
/// Note: The DRR definition used for defining patterns is shown below:
21+
///
22+
/// class Pattern<
23+
/// dag sourcePattern, list<dag> resultPatterns,
24+
/// list<dag> additionalConstraints = [],
25+
/// dag benefitsAdded = (addBenefit 0)
26+
/// >;
27+
28+
29+
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
30+
31+
def Eqn2 : Constraint<CPred<"::ValueEql2($0)">>;
32+
33+
def Pow2OptPattern : Pat<(MHLO_PowOp $arg, (MHLO_ConstantOp:$cst $cstVal)),
34+
(MHLO_MulOp $arg, $arg),
35+
[(TypesAreIdentical $arg, $cst),
36+
(Eqn2 $cst)]>;
37+
38+
#endif // POW_2

0 commit comments

Comments
 (0)