Skip to content

Commit 40684dc

Browse files
author
Umang Yadav
committed
Add floordiv for the deduce bound
Use fdiv in the tests for the deduce_bound
1 parent f5f2fee commit 40684dc

File tree

2 files changed

+17
-34
lines changed

2 files changed

+17
-34
lines changed

src/arithmetic/bound_deducer.cc

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -150,43 +150,26 @@ class BoundDeducer: public IRVisitor {
150150
// always use relax bound
151151
bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);
152152

153-
// TODO(tvm-team): use floordiv, which could give better bound.
154-
result_ = truncdiv(result_, operand);
153+
result_ = floordiv(result_, operand); // rounding down here
155154

156155
if (!divided) {
157-
// Handle non-divisible case
158-
// NOTE: this accounts for trunc div behavior.
159-
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
160-
161156
if (comp_op == kGreater) {
157+
// System will round down in all the cases, so add one for result_ for kGreater
158+
// (x >= 3/2 --> x >= 2)
159+
// (x >= -3/2 --> x >= -1)
160+
// (x >= 3/-2 --> x >= -1)
161+
// (x >= -3/-2 --> x >= 2)
162162
result_ += 1;
163163
} else if (comp_op == kEqual) {
164-
// condition unsatisfiable as with trunc div, it will change the expression
164+
// condition unsatisfiable as with floor div, it will change the expression
165165
success_ = false;
166166
return;
167167
} else {
168-
// NOTE: this is a bit sutble hack.
169-
//
170-
// condition:
171-
// - x * operand <= result
172-
// - operand > 0
173-
// - x >= 0
174-
//
175-
// Then it is fine to deduce that x <= result / operand.
176-
// - if result > 0, this division round down
177-
// - if result < 0, (result / operand) rounds up and may violate the constraint
178-
// however, given that x is always non-negative,
179-
// it is fine to have this relaxed bound, given that the user of deduce bound
180-
// will respect the bound of x
181-
//
182-
// TODO(tvm-team): think about a better API to incorporate constraint of x.
183-
// e.g. specify an interval of x and return a bound
184-
// that is in the interval and satisfies the condition.
185-
if (target_is_non_neg && sign_operand == kPositive) {
186-
// do nothing
187-
} else {
188-
result_ -= 1;
189-
}
168+
// System rounds down in all cases, do nothing for kLess.
169+
// ( x <= 3/2 --> x <= 1)
170+
// ( x <= -3/2 --> x <= -2)
171+
// ( x <= 3/-2 --> x <= -2)
172+
// ( x <= -3/-2 --> x <= 1)
190173
}
191174
}
192175
Visit(left ? op->a : op->b);

tests/python/unittest/test_arith_deduce_bound.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ def test_deduce():
3535
d_s = tvm.arith.IntervalSet(-3, -1)
3636
zero = tvm.const(0, "int32")
3737

38-
tdiv = tvm.truncdiv
38+
fdiv = tvm.floordiv
3939

4040
e0 = (-b)*a+c-d
4141
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
42-
ans0 = (tdiv(d - c, b*-1) + (-1))
42+
ans0 = fdiv(d - c, b*-1)
4343
assert_expr_equal(res0.max_value, ans0)
4444

4545
# expression containing variable a is on rhs
@@ -48,7 +48,7 @@ def test_deduce():
4848

4949
e0 = d*a+c-d
5050
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
51-
ans0 = (tdiv(d-c,d) - 1)
51+
ans0 = fdiv(d-c, d)
5252
assert_expr_equal(res0.max_value, ans0)
5353

5454
# expression containing variable a is on rhs
@@ -58,7 +58,7 @@ def test_deduce():
5858

5959
e1 = (a*4+b < c)
6060
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
61-
ans1 = (tdiv((c - b) + -1,4) -1)
61+
ans1 = fdiv(c-1-b, 4)
6262
assert_expr_equal(res1.max_value, ans1)
6363

6464

@@ -81,7 +81,7 @@ def test_deduce():
8181

8282
e3 = (-b)+a*c-d
8383
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
84-
ans3 = tdiv(2,c)+1
84+
ans3 = fdiv(2,c)+1
8585
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
8686

8787
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})

0 commit comments

Comments
 (0)