@@ -275,18 +275,25 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
275275 }
276276}
277277
278+ inline bool ReduceEqual (const ir::Reduce* a, const ir::Reduce* b) {
279+ return (a->combiner .same_as (b->combiner )) &&
280+ (a->source .same_as (b->source )) &&
281+ (a->axis .same_as (b->axis )) &&
282+ (a->condition .same_as (b->condition ));
283+ }
284+
278285void InjectInline (ScheduleNode* sch) {
279286 sch->InvalidateCache ();
280287
281- std::vector<Array<Expr>> new_body (sch->stages .size ());
288+ std::vector<Array<Expr> > new_body (sch->stages .size ());
282289 std::vector<bool > changed (sch->stages .size (), false );
283290 // inline all the ops
284291 for (size_t i = sch->stages .size (); i != 0 ; --i) {
285292 Stage stage = sch->stages [i - 1 ];
286293 if (stage->attach_type == kInline ) {
287294 stage->attach_type = kInlinedAlready ;
288295 Array<Var> args;
289- Array< Expr> body;
296+ Expr body;
290297 {
291298 // setup args
292299 const ComputeOpNode* compute = stage->op .as <ComputeOpNode>();
@@ -295,7 +302,9 @@ void InjectInline(ScheduleNode* sch) {
295302 for (auto iv : compute->axis ) {
296303 args.push_back (iv->var );
297304 }
298- body = compute->body ;
305+ CHECK_EQ (compute->body .size (), 1U )
306+ << " can only inline compute op with 1 output" ;
307+ body = compute->body [0 ];
299308 }
300309 for (size_t j = i; j < sch->stages .size (); ++j) {
301310 Stage s = sch->stages [j];
@@ -304,10 +313,39 @@ void InjectInline(ScheduleNode* sch) {
304313 if (!new_body[j].size ()) {
305314 new_body[j] = s->op .as <ComputeOpNode>()->body ;
306315 }
307- for (size_t k = 0 ; k < body.size (); ++k) {
308- changed[j] = true ;
309- new_body[j].Set (k, ir::Inline (ir::Evaluate::make (new_body[j][k]),
310- stage->op , args, body[k]).as <ir::Evaluate>()->value );
316+ if (new_body[j][0 ]->is_type <ir::Reduce>()) {
317+ // specially handle reduction inline for multiplre reductions.
318+ const ir::Reduce* reduce = new_body[j][0 ].as <ir::Reduce>();
319+ for (size_t k = 1 ; k < new_body[j].size (); ++k) {
320+ const ir::Reduce* reduce_ = new_body[j][k].as <ir::Reduce>();
321+ CHECK (reduce_);
322+ CHECK (ReduceEqual (reduce_, reduce))
323+ << " The Reduce inputs of ComputeOp should "
324+ << " have the same attribute except value_index" ;
325+ }
326+ Expr new_value = ir::Inline (ir::Evaluate::make (new_body[j][0 ]),
327+ stage->op , args, body).as <ir::Evaluate>()->value ;
328+ if (!new_value.same_as (new_body[j][0 ])) {
329+ changed[j] = true ;
330+ const ir::Reduce* r = new_value.as <ir::Reduce>();
331+ CHECK_EQ (new_body[j].size (), r->source .size ());
332+ CHECK (r != nullptr );
333+ for (size_t k = 0 ; k < new_body[j].size (); ++k) {
334+ std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
335+ n->value_index = static_cast <int >(k);
336+ n->type = r->source [k].type ();
337+ new_body[j].Set (k, Expr (n));
338+ }
339+ }
340+ } else {
341+ for (size_t k = 0 ; k < new_body[j].size (); ++k) {
342+ Expr new_value = ir::Inline (ir::Evaluate::make (new_body[j][k]),
343+ stage->op , args, body).as <ir::Evaluate>()->value ;
344+ if (!new_value.same_as (new_body[j][k])) {
345+ new_body[j].Set (k, new_value);
346+ changed[j] = true ;
347+ }
348+ }
311349 }
312350 }
313351 }
0 commit comments