29
29
#include " paddle/cinn/ir/ir_visitor.h"
30
30
#include " paddle/cinn/ir/op/ir_operators.h"
31
31
#include " paddle/cinn/ir/tensor.h"
32
+ #include " paddle/cinn/ir/utils/ir_copy.h"
32
33
#include " paddle/cinn/utils/string.h"
33
34
34
35
namespace cinn {
@@ -42,60 +43,15 @@ using utils::Replace;
42
43
43
44
namespace {
44
45
45
- bool TryEmplaceVarIntervals (const For& op,
46
- cinn::common::cas_intervals_t * var_intervals) {
47
- VLOG (4 ) << " TryEmplaceVarIntervals with min: " << op.min << " , " << op.extent ;
48
- auto * min_i = op.min .As <IntImm>();
49
- auto * extent_i = op.extent .As <IntImm>();
50
- // For containing zero Shape case, skip it.
51
- if (extent_i && extent_i->value <= 0 ) return false ;
52
-
53
- if (min_i && extent_i) {
54
- var_intervals->emplace (
55
- op.loop_var ->name ,
56
- cinn::common::CasInterval{min_i->value , extent_i->value - 1 });
57
- } else {
58
- var_intervals->emplace (op.loop_var ->name ,
59
- cinn::common::CasInterval{op.min , op.extent - 1 });
60
- }
61
- return true ;
62
- }
63
-
64
- bool TryEraseVarIntervals (const For& op,
65
- cinn::common::cas_intervals_t * var_intervals) {
66
- auto * min_i = op.min .As <IntImm>();
67
- auto * extent_i = op.extent .As <IntImm>();
68
- const auto & name = op.loop_var ->name ;
69
- const bool should_erase = min_i && extent_i && var_intervals->count (name);
70
- if (should_erase) {
71
- var_intervals->erase (name);
72
- }
73
- return should_erase;
74
- }
75
-
76
- // ! Simplify some sub-expression in the `expr`. Due to the simplify strategy
77
- // ! just fit several kinds of IR nodes, we partition the original expression to
78
- // ! several sub-expression those supported by simplify, and process each of
79
- // ! them.
80
- void PartialSimplify (Expr* expr,
81
- const cinn::common::cas_intervals_t & var_intervals = {}) {
82
- *expr = cinn::common::AutoSimplify (*expr, var_intervals);
83
- }
84
-
85
46
// ! Simplify the expression but Load.
86
47
struct SimplifyNoPureMathMutator : public ir ::IRMutator<ir::Expr*> {
87
- cinn::common::cas_intervals_t & var_intervals_;
88
- explicit SimplifyNoPureMathMutator (
89
- cinn::common::cas_intervals_t & var_intervals) // NOLINT
90
- : var_intervals_(var_intervals) {}
91
-
92
48
void operator ()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit (x, x); }
93
49
94
50
using ir::IRMutator<>::Visit;
95
51
96
52
#define __ (op__ ) \
97
53
void Visit (const op__* op, Expr* expr) override { \
98
- PartialSimplify ( expr, var_intervals_); \
54
+ * expr = ArithSimplify (*expr); \
99
55
}
100
56
101
57
__ (Add)
@@ -105,34 +61,6 @@ struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
105
61
__ (Min)
106
62
__ (Max)
107
63
#undef __
108
-
109
- void Visit (const PolyFor* op, Expr* expr) override {
110
- auto * node = expr->As <ir::PolyFor>();
111
- node->condition =
112
- cinn::common::SolveInequality (op->condition , op->iterator );
113
-
114
- Visit (&node->body , &node->body );
115
- }
116
-
117
- void Visit (const For* op, Expr* expr) override {
118
- auto * node = expr->As <ir::For>();
119
- Visit (&node->min , &node->min );
120
- Visit (&node->extent , &node->extent );
121
- TryEmplaceVarIntervals (*op, &var_intervals_);
122
- Visit (&node->body , &node->body );
123
- TryEraseVarIntervals (*op, &var_intervals_);
124
- }
125
-
126
- void Visit (const _Tensor_* op, Expr* expr) override {
127
- auto * node = expr->As <ir::_Tensor_>();
128
-
129
- for (auto & e : node->shape ) {
130
- PartialSimplify (&e, var_intervals_);
131
- }
132
- for (auto & e : node->domain ) {
133
- PartialSimplify (&e, var_intervals_);
134
- }
135
- }
136
64
};
137
65
138
66
struct SimplifyLoadMutator : public ir ::IRMutator<ir::Expr*> {
@@ -142,24 +70,18 @@ struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
142
70
auto * node = op->As <Load>();
143
71
for (auto & idx : node->indices ) {
144
72
if (cinn::common::IsPureMath (idx)) {
145
- PartialSimplify (& idx, var_intervals_ );
73
+ idx = ArithSimplify (idx );
146
74
} else {
147
- SimplifyNoPureMathMutator mutator (var_intervals_);
148
- mutator (&idx);
75
+ SimplifyNoPureMathMutator ()(&idx);
149
76
}
150
77
}
151
78
}
152
79
153
80
void Visit (const For* op, Expr* expr) override {
154
- TryEmplaceVarIntervals (*op, &var_intervals_);
155
81
auto * node = expr->As <For>();
156
82
operator ()(&node->body );
157
83
operator ()(&node->extent );
158
-
159
- TryEraseVarIntervals (*op, &var_intervals_);
160
84
}
161
-
162
- cinn::common::cas_intervals_t var_intervals_;
163
85
};
164
86
165
87
struct SimplifyStoreMutator : public ir ::IRMutator<ir::Expr*> {
@@ -170,24 +92,18 @@ struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
170
92
171
93
for (auto & idx : node->indices ) {
172
94
if (cinn::common::IsPureMath (idx)) {
173
- PartialSimplify (& idx, var_intervals_ );
95
+ idx = ArithSimplify (idx );
174
96
} else {
175
- SimplifyNoPureMathMutator mutator (var_intervals_);
176
- mutator (&idx);
97
+ SimplifyNoPureMathMutator ()(&idx);
177
98
}
178
99
}
179
100
}
180
101
181
102
void Visit (const For* op, Expr* expr) override {
182
- TryEmplaceVarIntervals (*op, &var_intervals_);
183
103
auto * node = expr->As <For>();
184
104
operator ()(&node->body );
185
105
operator ()(&node->extent );
186
-
187
- TryEraseVarIntervals (*op, &var_intervals_);
188
106
}
189
-
190
- cinn::common::cas_intervals_t var_intervals_;
191
107
};
192
108
193
109
struct SimplifyRampMutator : public ir ::IRMutator<Expr*> {
@@ -204,8 +120,8 @@ struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
204
120
cinn::common::IsPureMath (node->stride ),
205
121
true ,
206
122
::common::errors::InvalidArgument (" node->stride is not a pure math!" ));
207
- PartialSimplify (& node->base );
208
- PartialSimplify (& node->stride );
123
+ node-> base = ArithSimplify ( node->base );
124
+ node-> stride = ArithSimplify ( node->stride );
209
125
}
210
126
// ramp + ramp
211
127
void Visit (const Add* op, Expr* expr) override {
@@ -231,7 +147,6 @@ struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
231
147
232
148
void Visit (const IfThenElse* op, Expr* expr) override {
233
149
auto * node = expr->As <ir::IfThenElse>();
234
- node->condition = cinn::common::AutoSimplify (node->condition );
235
150
236
151
auto * condition_int = node->condition .As <ir::IntImm>();
237
152
auto * condition_uint = node->condition .As <ir::UIntImm>();
@@ -258,6 +173,122 @@ struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
258
173
}
259
174
};
260
175
176
+ struct SimplifySelectMutator : public ir ::IRMutator<> {
177
+ void operator ()(Expr* x) { ir::IRMutator<>::Visit (x, x); }
178
+
179
+ using ir::IRMutator<>::Visit;
180
+
181
+ void Visit (const Select* op, Expr* expr) override {
182
+ auto * node = expr->As <ir::Select>();
183
+
184
+ auto * condition_int = node->condition .As <ir::IntImm>();
185
+ auto * condition_uint = node->condition .As <ir::UIntImm>();
186
+
187
+ // not deterministic
188
+ if (!condition_int && !condition_uint) {
189
+ Visit (&node->true_value , &node->true_value );
190
+ Visit (&node->false_value , &node->false_value );
191
+ return ;
192
+ }
193
+
194
+ bool value = condition_int ? condition_int->value : condition_uint->value ;
195
+ if (value) {
196
+ *expr = op->true_value ;
197
+ Visit (expr, expr);
198
+ } else {
199
+ *expr = op->false_value ;
200
+ Visit (expr, expr);
201
+ }
202
+ }
203
+ };
204
+
205
+ struct SimplifyLogicalMutator : public ir ::ExprMutator<> {
206
+ void operator ()(Expr* expr) { ir::ExprMutator<>::Visit (expr, expr); }
207
+
208
+ #define DEFINE_VISIT_CMP_OP (OpType, Method ) \
209
+ void Visit (const ir::OpType* op, Expr* expr) override { \
210
+ VLOG (7 ) << " Begin Visit Cmp op: " << *expr; \
211
+ auto * node = expr->As <ir::OpType>(); \
212
+ ir::ExprMutator<>::Visit (&node->a (), &node->a ()); \
213
+ ir::ExprMutator<>::Visit (&node->b (), &node->b ()); \
214
+ if (node->a ().is_constant () && node->b ().is_constant ()) \
215
+ if (node->a ().get_constant () Method node->b ().get_constant ()) \
216
+ *expr = Expr (true ); \
217
+ VLOG (7 ) << " End Visit Cmp op: " << *expr; \
218
+ }
219
+ DEFINE_VISIT_CMP_OP (LE, <=)
220
+ DEFINE_VISIT_CMP_OP (LT, <)
221
+ DEFINE_VISIT_CMP_OP (GE, >=)
222
+ DEFINE_VISIT_CMP_OP (GT, >)
223
+ DEFINE_VISIT_CMP_OP (EQ, ==)
224
+ DEFINE_VISIT_CMP_OP (NE, !=)
225
+
226
+ #undef DEFINE_VISIT_CMP_OP
227
+
228
+ void Visit (const ir::And* op, Expr* expr) override {
229
+ VLOG (7 ) << " Begin Visit And op: " << *expr;
230
+ auto * node = expr->As <ir::And>();
231
+ ir::ExprMutator<>::Visit (&node->a (), &node->a ());
232
+ if (common::IsZero (node->a ())) {
233
+ *expr = Expr (false );
234
+ VLOG (7 ) << " End Visit And op: " << *expr;
235
+ return ;
236
+ }
237
+ ir::ExprMutator<>::Visit (&node->b (), &node->b ());
238
+ if (common::IsZero (node->b ())) {
239
+ VLOG (7 ) << " End Visit And op: " << *expr;
240
+ *expr = Expr (false );
241
+ return ;
242
+ }
243
+ if (common::IsOne (node->a ()) && common::IsOne (node->b ()))
244
+ *expr = Expr (true );
245
+ VLOG (7 ) << " End Visit And op: " << *expr;
246
+ }
247
+
248
+ void Visit (const ir::Or* op, Expr* expr) override {
249
+ VLOG (7 ) << " Begin Visit Or op: " << *expr;
250
+ auto * node = expr->As <ir::Or>();
251
+ ir::ExprMutator<>::Visit (&node->a (), &node->a ());
252
+ if (common::IsOne (node->a ())) {
253
+ *expr = Expr (true );
254
+ VLOG (7 ) << " End visit Or op: " << *expr;
255
+ return ;
256
+ }
257
+ ir::ExprMutator<>::Visit (&node->b (), &node->b ());
258
+ if (common::IsOne (node->b ())) {
259
+ *expr = Expr (true );
260
+ VLOG (7 ) << " End visit Or op: " << *expr;
261
+ return ;
262
+ }
263
+ if (common::IsZero (node->a ()) && common::IsZero (node->b ()))
264
+ *expr = Expr (false );
265
+ VLOG (7 ) << " End visit Or op: " << *expr;
266
+ }
267
+
268
+ void Visit (const ir::Not* op, Expr* expr) override {
269
+ auto * node = expr->As <ir::Not>();
270
+ auto v = node->v ();
271
+ ir::ExprMutator<>::Visit (&v, &v);
272
+ switch (v.node_type ()) {
273
+ case ir::IrNodeTy::IntImm:
274
+ case ir::IrNodeTy::UIntImm:
275
+ *expr = common::IsZero (v) ? Expr (true ) : Expr (false );
276
+ case ir::IrNodeTy::Not:
277
+ *expr = v.As <ir::Not>()->v ();
278
+ case ir::IrNodeTy::LE:
279
+ *expr = ir::GT::Make (v->operand (0 ), v->operand (1 ));
280
+ case ir::IrNodeTy::LT:
281
+ *expr = ir::GE::Make (v->operand (0 ), v->operand (1 ));
282
+ case ir::IrNodeTy::GE:
283
+ *expr = ir::LT::Make (v->operand (0 ), v->operand (1 ));
284
+ case ir::IrNodeTy::GT:
285
+ *expr = ir::LE::Make (v->operand (0 ), v->operand (1 ));
286
+ default :
287
+ return ;
288
+ }
289
+ }
290
+ };
291
+
261
292
struct ReplaceFracWithDivMutator : public ir ::IRMutator<> {
262
293
void operator ()(Expr* x) { ir::IRMutator<>::Visit (x, x); }
263
294
@@ -461,25 +492,32 @@ struct SimplifyCastMutator : public ir::IRMutator<> {
461
492
462
493
} // namespace
463
494
495
+ void SimplifyCast (Expr* expr) { SimplifyCastMutator ()(expr); }
496
+ void SimplifyForLoops (Expr* expr) { SimplifyForLoopsMutator ()(expr); }
497
+ void SimplifyBlocks (Expr* expr) { SimplifyBlocksMutator ()(expr); }
498
+
499
+ void SimplifyLogical (Expr* expr) { SimplifyLogicalMutator ()(expr); }
500
+
501
+ Expr ArithSimplify (const Expr& u) {
502
+ if (!u.is_index ()) return u;
503
+ auto copied = ir_utils::IRCopy (u);
504
+ return copied.as_index ().Normalize ();
505
+ }
506
+
464
507
void Simplify (Expr* expr) {
465
- VLOG (3 ) << " Begin Simplify " << *expr;
508
+ VLOG (6 ) << " Begin Simplify " << *expr;
509
+ SimplifyNoPureMathMutator ()(expr);
466
510
SimplifyCastMutator ()(expr);
467
511
SimplifyRampMutator ()(expr);
468
512
SimplifyLoadMutator ()(expr);
469
513
SimplifyStoreMutator ()(expr);
514
+ SimplifyLogicalMutator ()(expr);
470
515
SimplifyIfThenElseMutator ()(expr);
471
-
472
- cinn::common::cas_intervals_t var_intervals;
473
- SimplifyNoPureMathMutator mutator (var_intervals);
474
- mutator (expr);
516
+ SimplifySelectMutator ()(expr);
517
+ SimplifyNoPureMathMutator ()(expr);
475
518
476
519
ReplaceFracWithDivMutator ()(expr);
477
- VLOG (3 ) << " End Simplify " << *expr;
520
+ VLOG (6 ) << " End Simplify " << *expr;
478
521
}
479
-
480
- void SimplifyCast (Expr* expr) { SimplifyCastMutator ()(expr); }
481
- void SimplifyForLoops (Expr* expr) { SimplifyForLoopsMutator ()(expr); }
482
- void SimplifyBlocks (Expr* expr) { SimplifyBlocksMutator ()(expr); }
483
-
484
522
} // namespace optim
485
523
} // namespace cinn
0 commit comments