Skip to content

Commit f631fb4

Browse files
authored
[SCHEDULE] Fix inline with multiple outputs (#507)
1 parent af8cbdd commit f631fb4

File tree

3 files changed

+69
-8
lines changed

3 files changed

+69
-8
lines changed

src/op/compute_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Array<Tensor> compute(Array<Expr> shape,
9898
return outputs;
9999
}
100100

101-
bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
101+
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
102102
return (a->combiner.same_as(b->combiner)) &&
103103
(a->source.same_as(b->source)) &&
104104
(a->axis.same_as(b->axis)) &&

src/schedule/schedule_dataflow_rewrite.cc

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,18 +275,25 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
275275
}
276276
}
277277

278+
inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
279+
return (a->combiner.same_as(b->combiner)) &&
280+
(a->source.same_as(b->source)) &&
281+
(a->axis.same_as(b->axis)) &&
282+
(a->condition.same_as(b->condition));
283+
}
284+
278285
void InjectInline(ScheduleNode* sch) {
279286
sch->InvalidateCache();
280287

281-
std::vector<Array<Expr>> new_body(sch->stages.size());
288+
std::vector<Array<Expr> > new_body(sch->stages.size());
282289
std::vector<bool> changed(sch->stages.size(), false);
283290
// inline all the ops
284291
for (size_t i = sch->stages.size(); i != 0; --i) {
285292
Stage stage = sch->stages[i - 1];
286293
if (stage->attach_type == kInline) {
287294
stage->attach_type = kInlinedAlready;
288295
Array<Var> args;
289-
Array<Expr> body;
296+
Expr body;
290297
{
291298
// setup args
292299
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -295,7 +302,9 @@ void InjectInline(ScheduleNode* sch) {
295302
for (auto iv : compute->axis) {
296303
args.push_back(iv->var);
297304
}
298-
body = compute->body;
305+
CHECK_EQ(compute->body.size(), 1U)
306+
<< "can only inline compute op with 1 output";
307+
body = compute->body[0];
299308
}
300309
for (size_t j = i; j < sch->stages.size(); ++j) {
301310
Stage s = sch->stages[j];
@@ -304,10 +313,39 @@ void InjectInline(ScheduleNode* sch) {
304313
if (!new_body[j].size()) {
305314
new_body[j] = s->op.as<ComputeOpNode>()->body;
306315
}
307-
for (size_t k = 0; k < body.size(); ++k) {
308-
changed[j] = true;
309-
new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
310-
stage->op, args, body[k]).as<ir::Evaluate>()->value);
316+
if (new_body[j][0]->is_type<ir::Reduce>()) {
317+
// specially handle reduction inline for multiplre reductions.
318+
const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
319+
for (size_t k = 1; k < new_body[j].size(); ++k) {
320+
const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
321+
CHECK(reduce_);
322+
CHECK(ReduceEqual(reduce_, reduce))
323+
<< "The Reduce inputs of ComputeOp should "
324+
<< "have the same attribute except value_index";
325+
}
326+
Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
327+
stage->op, args, body).as<ir::Evaluate>()->value;
328+
if (!new_value.same_as(new_body[j][0])) {
329+
changed[j] = true;
330+
const ir::Reduce* r = new_value.as<ir::Reduce>();
331+
CHECK_EQ(new_body[j].size(), r->source.size());
332+
CHECK(r != nullptr);
333+
for (size_t k = 0; k < new_body[j].size(); ++k) {
334+
std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
335+
n->value_index = static_cast<int>(k);
336+
n->type = r->source[k].type();
337+
new_body[j].Set(k, Expr(n));
338+
}
339+
}
340+
} else {
341+
for (size_t k = 0; k < new_body[j].size(); ++k) {
342+
Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
343+
stage->op, args, body).as<ir::Evaluate>()->value;
344+
if (!new_value.same_as(new_body[j][k])) {
345+
new_body[j].Set(k, new_value);
346+
changed[j] = true;
347+
}
348+
}
311349
}
312350
}
313351
}

tests/python/unittest/test_schedule_schedule_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,28 @@ def test_schedule_scan():
5656
assert(bounds[res.op.scan_axis].min.value == 1)
5757
stmt = tvm.schedule.ScheduleOps(s, bounds)
5858

59+
def test_inline_multi_reduce():
60+
def argmax_comp(x, y):
61+
idx = tvm.select((x[1] >= y[1]), x[0], y[0])
62+
val = tvm.select((x[1] >= y[1]), x[1], y[1])
63+
return idx, val
64+
def argmax_init(idx_typ, val_typ):
65+
return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
66+
67+
argmax = tvm.comm_reducer(argmax_comp, argmax_init, name='argmax')
68+
m = tvm.var('m')
69+
n = tvm.var('n')
70+
val = tvm.placeholder((m, n), name='val', dtype='float32')
71+
val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
72+
k = tvm.reduce_axis((0, n), 'k')
73+
T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
74+
s = tvm.create_schedule(T_idx.op)
75+
s[val2].compute_inline()
76+
s = s.normalize()
77+
bounds = tvm.schedule.InferBound(s)
78+
stmt = tvm.schedule.ScheduleOps(s, bounds)
79+
80+
5981
def test_auto_inline():
6082
m = tvm.var('m')
6183
n = tvm.var('n')
@@ -207,6 +229,7 @@ def test_schedule_cache_relayout3():
207229

208230

209231
if __name__ == "__main__":
232+
test_inline_multi_reduce()
210233
test_schedule_cache_relayout3()
211234
test_schedule_cache_relayout2()
212235
test_schedule_cache_relayout1()

0 commit comments

Comments
 (0)