Skip to content

Commit 3555769

Browse files
Ziheng Jiangtqchen
authored andcommitted
[ARITH] Add CombineInterval<Div> in IntSet (#48)
* [FIX] add CombineInterval<Div> * fix error message and add comment about rounding * fix comment
1 parent c8ec411 commit 3555769

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

src/arithmetic/int_set.cc

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
244244
if (is_one(b.min)) return IntervalSet::make(a);
245245
Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
246246
Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
247-
// This is relaxiation
247+
// no relaxation is needed in here due to set is inclusive
248248
// TODO(tqchen): consider convert to StrideSet.
249249
if (is_positive_const(b.min)) {
250250
return IntervalSet::make(e1, e2);
@@ -259,6 +259,32 @@ inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
259259
return IntSet::everything();
260260
}
261261

262+
template<>
263+
inline IntSet CombineInterval<Div>(Interval a, Interval b) {
264+
if (a.is_single_point() && b.is_single_point()) {
265+
return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
266+
}
267+
if (b.is_single_point()) {
268+
if (is_zero(b.min)) {
269+
LOG(FATAL) << "Divide by zero in CombineInterval Div";
270+
}
271+
if (is_one(b.min)) return IntervalSet::make(a);
272+
Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
273+
Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
274+
// no relaxation is needed in here due to set is inclusive
275+
if (is_positive_const(b.min)) {
276+
return IntervalSet::make(e1, e2);
277+
} else if (is_negative_const(b.min)) {
278+
return IntervalSet::make(e2, e1);
279+
} else if (a.is_bounded()) {
280+
Expr cmp = b.min >= make_zero(b.min.type().element_of());
281+
return IntervalSet::make(select(cmp, e1, e2), select(cmp, e2, e1));
282+
}
283+
}
284+
LOG(WARNING) << "Return Everything in CombineInterval Div";
285+
return IntSet::everything();
286+
}
287+
262288
template<>
263289
inline IntSet CombineInterval<Max>(Interval a, Interval b) {
264290
if (a.is_single_point() && b.is_single_point()) {

0 commit comments

Comments
 (0)