Skip to content

Commit 9ee7dbd

Browse files
bulanova-huawei366anna
authored andcommitted
tightening bounding box for IntSet fused in PassUpDomain
1 parent 8471f81 commit 9ee7dbd

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

src/schedule/message_passing.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,24 @@ void PassUpDomain(const FuseNode* s,
270270
*outer = IntSet::single_point(v_outer);
271271
*inner = IntSet::single_point(v_inner);
272272
} else {
273-
LOG(WARNING) << "use fallback inference rule in fuse";
274-
// simply use the entire set, this rule can be enhanced.
275-
*outer = IntSet::range(dom_map.at(s->outer));
276-
*inner = IntSet::range(dom_map.at(s->inner));
273+
Expr fused_extent = (fused.max() - fused.min() + 1);
274+
Expr inner_extent = dom_map.at(s->inner)->extent;
275+
*outer = IntSet::interval(outer_min + fused.min() / inner_extent,
276+
outer_min + fused.max() / inner_extent);
277+
if ( is_zero(Simplify(inner_extent % fused_extent)) &&
278+
is_zero(Simplify(fused.min() % fused_extent)) ) {
279+
// fused never spans multiple rows, make a tight bounding box
280+
// there may be other cases when bounding box could be tightened
281+
*inner = IntSet::interval(inner_min + fused.min() % inner_extent,
282+
inner_min + fused.max() % inner_extent);
283+
} else { // fused may span multiple rows, use full row widths
284+
if (!is_zero(Simplify(fused_extent % inner_extent)) ||
285+
!is_zero(Simplify(fused.min() % inner_extent))) {
286+
LOG(WARNING) <<
287+
"fused and original axes are not aligned, this may cause redundant computations";
288+
}
289+
*inner = IntSet::range(dom_map.at(s->inner));
290+
}
277291
return;
278292
}
279293
}

tests/python/unittest/test_schedule_bound_inference.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,46 @@ def test_bound3():
6969
assert(bounds[A1.op.axis[0]].extent.value==32)
7070
assert(bounds[A1.op.axis[1]].extent.value==16)
7171

72+
def test_bound_fusesplit1():
73+
m = tvm.var('m')
74+
l = tvm.var('l')
75+
split = tvm.var('s')
76+
A = tvm.placeholder((m, l), name='A')
77+
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
78+
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
79+
80+
s = tvm.create_schedule(A2.op)
81+
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
82+
xo, xi = s[A2].split(fused_axes, split)
83+
s[A1].compute_at(s[A2], xo)
84+
85+
bounds = tvm.schedule.InferBound(s)
86+
assert isinstance(bounds, tvm.container.Map)
87+
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split) / l ).value == 0)
88+
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].extent - (((xo + 1) * split - 1) / l - (xo * split) / l + 1)).value == 0)
89+
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0)
90+
91+
def test_bound_fusesplit2():
92+
m = tvm.var("m")
93+
l = tvm.convert(6)
94+
split = tvm.convert(3)
95+
A = tvm.placeholder((m, l), name='A')
96+
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
97+
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
98+
99+
s = tvm.create_schedule(A2.op)
100+
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
101+
xo, xi = s[A2].split(fused_axes, split)
102+
s[A1].compute_at(s[A2], xo)
103+
104+
bounds = tvm.schedule.InferBound(s)
105+
assert isinstance(bounds, tvm.container.Map)
106+
vars = tvm.convert({xo.var: tvm.const(5, "int32")})
107+
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2)
108+
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3)
109+
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1)
110+
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3)
111+
72112

73113
def test_bound_warp():
74114
m = tvm.var('m')
@@ -320,3 +360,5 @@ def _body():
320360
test_gemm_bound()
321361
test_bound_warp()
322362
test_bound_tensor_compute_op()
363+
test_bound_fusesplit1()
364+
test_bound_fusesplit2()

0 commit comments

Comments
 (0)