Skip to content

Commit cf26ac9

Browse files
Sen Yangylc
authored andcommitted
[Bugfix] Fix div zero error in rewrite_simplify (apache#8961)
* fix div zero error in rewrite_simplify * update the style to fix ci error * remove useless code and comment
1 parent e7847ca commit cf26ac9

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

src/arith/rewrite_simplify.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
474474
if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) {
475475
int64_t c1val = c1.Eval()->value;
476476
int64_t c2val = c2.Eval()->value;
477+
ICHECK(c2val != 0) << "division by zero";
477478
if (c1val % c2val == 0) {
478479
return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
479480
}

tests/python/unittest/test_arith_rewrite_simplify.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import pytest
1718
import tvm
1819
from tvm import te
1920

@@ -931,20 +932,13 @@ def test_shift_left_simplify():
931932
ck.verify(z, tvm.tir.const(1 << 10, "int32"))
932933

933934

935+
def test_div_zero_simplify():
936+
ck = RewriteChecker()
937+
938+
with pytest.raises(tvm.error.TVMError) as cm:
939+
ck.analyzer.rewrite_simplify(tvm.tir.Div(tvm.tir.Ramp(1, 1, 2), tvm.tir.Broadcast(0, 2)))
940+
assert "division by zero" in str(cm.execption)
941+
942+
934943
if __name__ == "__main__":
935-
test_floordiv_index_simplify()
936-
test_floormod_index_simplify()
937-
test_cmp_simplify()
938-
test_vector_simplify()
939-
test_add_index_simplify()
940-
test_sub_index_simplify()
941-
test_mul_index_simplify()
942-
test_div_index_simplify()
943-
test_max_index_simplify()
944-
test_min_index_simplify()
945-
test_mod_index_simplify()
946-
test_select_simplify()
947-
test_logical_simplify()
948-
test_let_simplify()
949-
test_cast_simplify()
950-
test_shift_left_simplify()
944+
pytest.main([__file__])

0 commit comments

Comments
 (0)