Skip to content

Commit 2bcf3f2

Browse files
Ziheng Jiangtqchen
authored andcommitted
fix Stage.fuse (#33)
1 parent e42cc11 commit 2bcf3f2

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

src/api/api_lang.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ TVM_REGISTER_API(_StageFuse)
216216
.set_body([](TVMArgs args, TVMRetValue* ret) {
217217
IterVar fused;
218218
args[0].operator Stage()
219-
.split(args[1], args[2], &fused);
219+
.fuse(args[1], args[2], &fused);
220220
*ret = fused;
221221
});
222222

src/schedule/schedule_lang.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor
117117

118118
Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*)
119119
IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
120+
*p_target = fused;
120121
StageNode* self = operator->();
121122
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
122123
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
@@ -129,7 +130,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT
129130
CHECK_EQ(pos_inner, pos_outer + 1)
130131
<< "Can only fuse iterations that are consecutive between each other";
131132
leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
132-
leaf_vars->data.begin() + pos_inner);
133+
leaf_vars->data.begin() + pos_inner + 1);
133134
leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
134135
fused.node_);
135136
return *this;

tests/python/unittest/test_lang_schedule.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,23 @@ def test_tile():
6363
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
6464
assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi)
6565

66+
67+
def test_fuse():
68+
m = tvm.Var('m')
69+
n = tvm.Var('n')
70+
A = tvm.placeholder((m, n), name='A')
71+
T = tvm.compute((m, n), lambda i, j: A[i, j])
72+
73+
s = tvm.Schedule(T.op)
74+
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
75+
fused = s[T].fuse(yo, xo)
76+
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
77+
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
78+
79+
6680
if __name__ == "__main__":
6781
test_schedule_create()
6882
test_reorder()
6983
test_tile()
7084
test_split()
85+
test_fuse()

0 commit comments

Comments
 (0)