Skip to content

Commit f6af949

Browse files
authored
【CINN】Add simplify mutator for logical op and compare op (#70456)
* add mutator * fix conflict * fix bug * fix ci bug * fix ci bug * fix bug
1 parent 5f4b91c commit f6af949

File tree

4 files changed

+148
-107
lines changed

4 files changed

+148
-107
lines changed

paddle/cinn/common/cas.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ struct CasInterval {
5757
VLOG(6) << "CasInterval is : [" << expr_l << ", " << expr_r << "].";
5858
expr_r = detail::ReplaceMinToConstant(expr_r);
5959
expr_l = detail::ReplaceMaxToConstant(expr_l);
60-
optim::Simplify(&expr_l);
61-
optim::Simplify(&expr_r);
60+
expr_l = optim::ArithSimplify(expr_l);
61+
expr_r = optim::ArithSimplify(expr_r);
6262
VLOG(6) << "After simplify, CasInterval is : [" << expr_l << ", " << expr_r
6363
<< "].";
6464

paddle/cinn/ir/ir_mutator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class ExprMutator : public IRMutator<T> {
4646
if (expr->is_index()) return;
4747
IRMutator<T>::Visit(expr, op);
4848
}
49-
void Visit(const IndexExpr *expr, IndexExpr *op) override { return; }
49+
void Visit(const IndexExpr *expr, T op) override { return; }
5050
};
5151

5252
template <typename T>

paddle/cinn/optim/ir_simplify.cc

Lines changed: 142 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "paddle/cinn/ir/ir_visitor.h"
3030
#include "paddle/cinn/ir/op/ir_operators.h"
3131
#include "paddle/cinn/ir/tensor.h"
32+
#include "paddle/cinn/ir/utils/ir_copy.h"
3233
#include "paddle/cinn/utils/string.h"
3334

3435
namespace cinn {
@@ -42,60 +43,15 @@ using utils::Replace;
4243

4344
namespace {
4445

45-
bool TryEmplaceVarIntervals(const For& op,
46-
cinn::common::cas_intervals_t* var_intervals) {
47-
VLOG(4) << "TryEmplaceVarIntervals with min: " << op.min << ", " << op.extent;
48-
auto* min_i = op.min.As<IntImm>();
49-
auto* extent_i = op.extent.As<IntImm>();
50-
// For containing zero Shape case, skip it.
51-
if (extent_i && extent_i->value <= 0) return false;
52-
53-
if (min_i && extent_i) {
54-
var_intervals->emplace(
55-
op.loop_var->name,
56-
cinn::common::CasInterval{min_i->value, extent_i->value - 1});
57-
} else {
58-
var_intervals->emplace(op.loop_var->name,
59-
cinn::common::CasInterval{op.min, op.extent - 1});
60-
}
61-
return true;
62-
}
63-
64-
bool TryEraseVarIntervals(const For& op,
65-
cinn::common::cas_intervals_t* var_intervals) {
66-
auto* min_i = op.min.As<IntImm>();
67-
auto* extent_i = op.extent.As<IntImm>();
68-
const auto& name = op.loop_var->name;
69-
const bool should_erase = min_i && extent_i && var_intervals->count(name);
70-
if (should_erase) {
71-
var_intervals->erase(name);
72-
}
73-
return should_erase;
74-
}
75-
76-
//! Simplify some sub-expression in the `expr`. Due to the simplify strategy
77-
//! just fit several kinds of IR nodes, we partition the original expression to
78-
//! several sub-expression those supported by simplify, and process each of
79-
//! them.
80-
void PartialSimplify(Expr* expr,
81-
const cinn::common::cas_intervals_t& var_intervals = {}) {
82-
*expr = cinn::common::AutoSimplify(*expr, var_intervals);
83-
}
84-
8546
//! Simplify the expression but Load.
8647
struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
87-
cinn::common::cas_intervals_t& var_intervals_;
88-
explicit SimplifyNoPureMathMutator(
89-
cinn::common::cas_intervals_t& var_intervals) // NOLINT
90-
: var_intervals_(var_intervals) {}
91-
9248
void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
9349

9450
using ir::IRMutator<>::Visit;
9551

9652
#define __(op__) \
9753
void Visit(const op__* op, Expr* expr) override { \
98-
PartialSimplify(expr, var_intervals_); \
54+
*expr = ArithSimplify(*expr); \
9955
}
10056

10157
__(Add)
@@ -105,34 +61,6 @@ struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
10561
__(Min)
10662
__(Max)
10763
#undef __
108-
109-
void Visit(const PolyFor* op, Expr* expr) override {
110-
auto* node = expr->As<ir::PolyFor>();
111-
node->condition =
112-
cinn::common::SolveInequality(op->condition, op->iterator);
113-
114-
Visit(&node->body, &node->body);
115-
}
116-
117-
void Visit(const For* op, Expr* expr) override {
118-
auto* node = expr->As<ir::For>();
119-
Visit(&node->min, &node->min);
120-
Visit(&node->extent, &node->extent);
121-
TryEmplaceVarIntervals(*op, &var_intervals_);
122-
Visit(&node->body, &node->body);
123-
TryEraseVarIntervals(*op, &var_intervals_);
124-
}
125-
126-
void Visit(const _Tensor_* op, Expr* expr) override {
127-
auto* node = expr->As<ir::_Tensor_>();
128-
129-
for (auto& e : node->shape) {
130-
PartialSimplify(&e, var_intervals_);
131-
}
132-
for (auto& e : node->domain) {
133-
PartialSimplify(&e, var_intervals_);
134-
}
135-
}
13664
};
13765

13866
struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
@@ -142,24 +70,18 @@ struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
14270
auto* node = op->As<Load>();
14371
for (auto& idx : node->indices) {
14472
if (cinn::common::IsPureMath(idx)) {
145-
PartialSimplify(&idx, var_intervals_);
73+
idx = ArithSimplify(idx);
14674
} else {
147-
SimplifyNoPureMathMutator mutator(var_intervals_);
148-
mutator(&idx);
75+
SimplifyNoPureMathMutator()(&idx);
14976
}
15077
}
15178
}
15279

15380
void Visit(const For* op, Expr* expr) override {
154-
TryEmplaceVarIntervals(*op, &var_intervals_);
15581
auto* node = expr->As<For>();
15682
operator()(&node->body);
15783
operator()(&node->extent);
158-
159-
TryEraseVarIntervals(*op, &var_intervals_);
16084
}
161-
162-
cinn::common::cas_intervals_t var_intervals_;
16385
};
16486

16587
struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
@@ -170,24 +92,18 @@ struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
17092

17193
for (auto& idx : node->indices) {
17294
if (cinn::common::IsPureMath(idx)) {
173-
PartialSimplify(&idx, var_intervals_);
95+
idx = ArithSimplify(idx);
17496
} else {
175-
SimplifyNoPureMathMutator mutator(var_intervals_);
176-
mutator(&idx);
97+
SimplifyNoPureMathMutator()(&idx);
17798
}
17899
}
179100
}
180101

181102
void Visit(const For* op, Expr* expr) override {
182-
TryEmplaceVarIntervals(*op, &var_intervals_);
183103
auto* node = expr->As<For>();
184104
operator()(&node->body);
185105
operator()(&node->extent);
186-
187-
TryEraseVarIntervals(*op, &var_intervals_);
188106
}
189-
190-
cinn::common::cas_intervals_t var_intervals_;
191107
};
192108

193109
struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
@@ -204,8 +120,8 @@ struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
204120
cinn::common::IsPureMath(node->stride),
205121
true,
206122
::common::errors::InvalidArgument("node->stride is not a pure math!"));
207-
PartialSimplify(&node->base);
208-
PartialSimplify(&node->stride);
123+
node->base = ArithSimplify(node->base);
124+
node->stride = ArithSimplify(node->stride);
209125
}
210126
// ramp + ramp
211127
void Visit(const Add* op, Expr* expr) override {
@@ -231,7 +147,6 @@ struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
231147

232148
void Visit(const IfThenElse* op, Expr* expr) override {
233149
auto* node = expr->As<ir::IfThenElse>();
234-
node->condition = cinn::common::AutoSimplify(node->condition);
235150

236151
auto* condition_int = node->condition.As<ir::IntImm>();
237152
auto* condition_uint = node->condition.As<ir::UIntImm>();
@@ -258,6 +173,122 @@ struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
258173
}
259174
};
260175

176+
struct SimplifySelectMutator : public ir::IRMutator<> {
177+
void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
178+
179+
using ir::IRMutator<>::Visit;
180+
181+
void Visit(const Select* op, Expr* expr) override {
182+
auto* node = expr->As<ir::Select>();
183+
184+
auto* condition_int = node->condition.As<ir::IntImm>();
185+
auto* condition_uint = node->condition.As<ir::UIntImm>();
186+
187+
// not deterministic
188+
if (!condition_int && !condition_uint) {
189+
Visit(&node->true_value, &node->true_value);
190+
Visit(&node->false_value, &node->false_value);
191+
return;
192+
}
193+
194+
bool value = condition_int ? condition_int->value : condition_uint->value;
195+
if (value) {
196+
*expr = op->true_value;
197+
Visit(expr, expr);
198+
} else {
199+
*expr = op->false_value;
200+
Visit(expr, expr);
201+
}
202+
}
203+
};
204+
205+
struct SimplifyLogicalMutator : public ir::ExprMutator<> {
206+
void operator()(Expr* expr) { ir::ExprMutator<>::Visit(expr, expr); }
207+
208+
#define DEFINE_VISIT_CMP_OP(OpType, Method) \
209+
void Visit(const ir::OpType* op, Expr* expr) override { \
210+
VLOG(7) << "Begin Visit Cmp op: " << *expr; \
211+
auto* node = expr->As<ir::OpType>(); \
212+
ir::ExprMutator<>::Visit(&node->a(), &node->a()); \
213+
ir::ExprMutator<>::Visit(&node->b(), &node->b()); \
214+
if (node->a().is_constant() && node->b().is_constant()) \
215+
if (node->a().get_constant() Method node->b().get_constant()) \
216+
*expr = Expr(true); \
217+
VLOG(7) << "End Visit Cmp op: " << *expr; \
218+
}
219+
DEFINE_VISIT_CMP_OP(LE, <=)
220+
DEFINE_VISIT_CMP_OP(LT, <)
221+
DEFINE_VISIT_CMP_OP(GE, >=)
222+
DEFINE_VISIT_CMP_OP(GT, >)
223+
DEFINE_VISIT_CMP_OP(EQ, ==)
224+
DEFINE_VISIT_CMP_OP(NE, !=)
225+
226+
#undef DEFINE_VISIT_CMP_OP
227+
228+
void Visit(const ir::And* op, Expr* expr) override {
229+
VLOG(7) << "Begin Visit And op: " << *expr;
230+
auto* node = expr->As<ir::And>();
231+
ir::ExprMutator<>::Visit(&node->a(), &node->a());
232+
if (common::IsZero(node->a())) {
233+
*expr = Expr(false);
234+
VLOG(7) << "End Visit And op: " << *expr;
235+
return;
236+
}
237+
ir::ExprMutator<>::Visit(&node->b(), &node->b());
238+
if (common::IsZero(node->b())) {
239+
VLOG(7) << "End Visit And op: " << *expr;
240+
*expr = Expr(false);
241+
return;
242+
}
243+
if (common::IsOne(node->a()) && common::IsOne(node->b()))
244+
*expr = Expr(true);
245+
VLOG(7) << "End Visit And op: " << *expr;
246+
}
247+
248+
void Visit(const ir::Or* op, Expr* expr) override {
249+
VLOG(7) << "Begin Visit Or op: " << *expr;
250+
auto* node = expr->As<ir::Or>();
251+
ir::ExprMutator<>::Visit(&node->a(), &node->a());
252+
if (common::IsOne(node->a())) {
253+
*expr = Expr(true);
254+
VLOG(7) << "End visit Or op: " << *expr;
255+
return;
256+
}
257+
ir::ExprMutator<>::Visit(&node->b(), &node->b());
258+
if (common::IsOne(node->b())) {
259+
*expr = Expr(true);
260+
VLOG(7) << "End visit Or op: " << *expr;
261+
return;
262+
}
263+
if (common::IsZero(node->a()) && common::IsZero(node->b()))
264+
*expr = Expr(false);
265+
VLOG(7) << "End visit Or op: " << *expr;
266+
}
267+
268+
void Visit(const ir::Not* op, Expr* expr) override {
269+
auto* node = expr->As<ir::Not>();
270+
auto v = node->v();
271+
ir::ExprMutator<>::Visit(&v, &v);
272+
switch (v.node_type()) {
273+
case ir::IrNodeTy::IntImm:
274+
case ir::IrNodeTy::UIntImm:
275+
*expr = common::IsZero(v) ? Expr(true) : Expr(false);
276+
case ir::IrNodeTy::Not:
277+
*expr = v.As<ir::Not>()->v();
278+
case ir::IrNodeTy::LE:
279+
*expr = ir::GT::Make(v->operand(0), v->operand(1));
280+
case ir::IrNodeTy::LT:
281+
*expr = ir::GE::Make(v->operand(0), v->operand(1));
282+
case ir::IrNodeTy::GE:
283+
*expr = ir::LT::Make(v->operand(0), v->operand(1));
284+
case ir::IrNodeTy::GT:
285+
*expr = ir::LE::Make(v->operand(0), v->operand(1));
286+
default:
287+
return;
288+
}
289+
}
290+
};
291+
261292
struct ReplaceFracWithDivMutator : public ir::IRMutator<> {
262293
void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
263294

@@ -461,25 +492,32 @@ struct SimplifyCastMutator : public ir::IRMutator<> {
461492

462493
} // namespace
463494

495+
void SimplifyCast(Expr* expr) { SimplifyCastMutator()(expr); }
496+
void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); }
497+
void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); }
498+
499+
void SimplifyLogical(Expr* expr) { SimplifyLogicalMutator()(expr); }
500+
501+
Expr ArithSimplify(const Expr& u) {
502+
if (!u.is_index()) return u;
503+
auto copied = ir_utils::IRCopy(u);
504+
return copied.as_index().Normalize();
505+
}
506+
464507
void Simplify(Expr* expr) {
465-
VLOG(3) << "Begin Simplify " << *expr;
508+
VLOG(6) << "Begin Simplify " << *expr;
509+
SimplifyNoPureMathMutator()(expr);
466510
SimplifyCastMutator()(expr);
467511
SimplifyRampMutator()(expr);
468512
SimplifyLoadMutator()(expr);
469513
SimplifyStoreMutator()(expr);
514+
SimplifyLogicalMutator()(expr);
470515
SimplifyIfThenElseMutator()(expr);
471-
472-
cinn::common::cas_intervals_t var_intervals;
473-
SimplifyNoPureMathMutator mutator(var_intervals);
474-
mutator(expr);
516+
SimplifySelectMutator()(expr);
517+
SimplifyNoPureMathMutator()(expr);
475518

476519
ReplaceFracWithDivMutator()(expr);
477-
VLOG(3) << "End Simplify " << *expr;
520+
VLOG(6) << "End Simplify " << *expr;
478521
}
479-
480-
void SimplifyCast(Expr* expr) { SimplifyCastMutator()(expr); }
481-
void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); }
482-
void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); }
483-
484522
} // namespace optim
485523
} // namespace cinn

paddle/cinn/optim/ir_simplify.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,5 +183,8 @@ void SimplifyForLoops(Expr *expr);
183183
*/
184184
void SimplifyBlocks(Expr *expr);
185185

186+
void SimplifyLogical(Expr *expr);
187+
188+
Expr ArithSimplify(const Expr &u);
186189
} // namespace optim
187190
} // namespace cinn

0 commit comments

Comments
 (0)