Skip to content

Commit f5f2fee

Browse files
authored
[ARITH] migrate indexdiv/mod to floordiv/mod (#4008)
1 parent 2dac17d commit f5f2fee

File tree

9 files changed

+38
-19
lines changed

9 files changed

+38
-19
lines changed

python/tvm/expr.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,13 @@ def __rtruediv__(self, other):
9292
return _generic.divide(other, self)
9393

9494
def __floordiv__(self, other):
95-
# return _generic.floordiv(self, other)
96-
return _generic.divide(self, other)
95+
return _generic.floordiv(self, other)
9796

9897
def __rfloordiv__(self, other):
99-
# return _generic.floordiv(other, self)
100-
return _generic.divide(other, self)
98+
return _generic.floordiv(other, self)
10199

102100
def __mod__(self, other):
103-
raise div_ambiguity_error()
104-
# return _make._OpMod(self, other)
101+
return _make._OpFloorMod(self, other)
105102

106103
def __neg__(self):
107104
neg_one = _api_internal._const(-1, self.dtype)

src/lang/attr_functor.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -87,6 +87,8 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
8787
virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
8888
virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
8989
virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
90+
virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT;
91+
virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
9092
virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
9193
virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
9294
virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
@@ -119,6 +121,9 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
119121
ATTR_FUNCTOR_DISPATCH(Sub);
120122
ATTR_FUNCTOR_DISPATCH(Mul);
121123
ATTR_FUNCTOR_DISPATCH(Div);
124+
ATTR_FUNCTOR_DISPATCH(Mod);
125+
ATTR_FUNCTOR_DISPATCH(FloorDiv);
126+
ATTR_FUNCTOR_DISPATCH(FloorMod);
122127
ATTR_FUNCTOR_DISPATCH(Min);
123128
ATTR_FUNCTOR_DISPATCH(Max);
124129
ATTR_FUNCTOR_DISPATCH(GE);
@@ -160,6 +165,8 @@ class AttrsEqualHandler :
160165
bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final;
161166
bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final;
162167
bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final;
168+
bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final;
169+
bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final;
163170
bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final;
164171
bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final;
165172
bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final;
@@ -201,6 +208,8 @@ class AttrsHashHandler :
201208
size_t VisitAttr_(const ir::Mul* op) final;
202209
size_t VisitAttr_(const ir::Div* op) final;
203210
size_t VisitAttr_(const ir::Mod* op) final;
211+
size_t VisitAttr_(const ir::FloorDiv* op) final;
212+
size_t VisitAttr_(const ir::FloorMod* op) final;
204213
size_t VisitAttr_(const ir::Min* op) final;
205214
size_t VisitAttr_(const ir::Max* op) final;
206215
size_t VisitAttr_(const ir::GE* op) final;

src/lang/attrs.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -154,6 +154,8 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
154154
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
155155
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
156156
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
157+
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv);
158+
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod);
157159
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
158160
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
159161
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
@@ -266,6 +268,8 @@ TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
266268
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
267269
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
268270
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
271+
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv);
272+
TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod);
269273
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
270274
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
271275
TVM_DEFINE_ATTRS_BINOP_HASH(GE);

src/lang/buffer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
namespace tvm {
3333

3434
// TODO(tqchen): change to floormod/div
35-
using IndexMod = ir::Mod;
36-
using IndexDiv = ir::Div;
35+
using IndexMod = ir::FloorMod;
36+
using IndexDiv = ir::FloorDiv;
3737

3838
Array<Expr> SimplifyArray(Array<Expr> array) {
3939
for (size_t i = 0; i < array.size(); ++i) {

src/lang/expr_operator.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) {
208208

209209
// TODO(tqchen): switch to floordiv
210210
Expr indexdiv(Expr a, Expr b) {
211-
return truncdiv(a, b);
211+
return floordiv(a, b);
212212
}
213213

214214
Expr indexmod(Expr a, Expr b) {
215-
return truncmod(a, b);
215+
return floormod(a, b);
216216
}
217217

218218
Expr floordiv(Expr a, Expr b) {

src/pass/lower_intrin.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
4646
patterns_.push_back("tvm.intrin.rule." + starget + ".");
4747
patterns_.push_back("tvm.intrin.rule.default.");
4848
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
49+
if (target == "stackvm") {
50+
support_bitwise_op_ = false;
51+
}
4952
}
5053

5154
Expr Mutate_(const Call* op, const Expr& e) final {
@@ -76,7 +79,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
7679
const DataType& dtype = op->type;
7780
CHECK(dtype.is_int() || !dtype.is_uint());
7881

79-
if (is_const_power_of_two_integer(op->b, &shift)) {
82+
if (support_bitwise_op_ &&
83+
is_const_power_of_two_integer(op->b, &shift)) {
8084
// lower to right shift if possible.
8185
return op->a >> make_const(dtype, shift);
8286
}
@@ -93,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
9397
// condition on b >= 0.
9498
// truncmod(a, b) < 0 implies ceildiv,
9599
// So we need to correct these cases.
96-
if (dtype == Int(32) || dtype == Int(64)) {
100+
if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
97101
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
98102
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
99103
} else {
@@ -122,7 +126,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
122126
const DataType& dtype = op->type;
123127
CHECK(dtype.is_int() || !dtype.is_uint());
124128

125-
if (is_const_power_of_two_integer(op->b, &shift)) {
129+
if (support_bitwise_op_ &&
130+
is_const_power_of_two_integer(op->b, &shift)) {
126131
// lower to masking if possible.
127132
int64_t mask = (
128133
static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
@@ -140,7 +145,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
140145
// mod(a, b) < 0 will imply we are doing ceildiv,
141146
// So we need to correct these cases.
142147
Expr rmod = truncmod(op->a, op->b);
143-
if (dtype == Int(32) || dtype == Int(64)) {
148+
if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
144149
// (rmod >> shift) & b
145150
// -> (rmod >= 0 ? 0: -1) & b
146151
// -> rmod >= 0 ? 0 : b
@@ -268,6 +273,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
268273
// patterns
269274
std::vector<std::string> patterns_;
270275
const PackedFunc* fma_{nullptr};
276+
bool support_bitwise_op_{true};
271277
};
272278

273279
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {

tests/python/unittest/test_codegen_device.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def test_add_pipeline():
4848
stmt = tvm.ir_pass.Simplify(stmt)
4949
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
5050
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
51+
# lower the floordiv (use stackvm rules so it works for all targets)
52+
fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
5153
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
5254

5355
def check_target(device, host="stackvm"):

tests/python/unittest/test_codegen_vm_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def tvm_call_back_get_shape(shape0):
3737
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
3838
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
3939
fapi = tvm.ir_pass.LowerTVMBuiltin(fapi)
40+
fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm")
4041
run_jit(fapi, lambda f: f(a))
4142

4243

topi/python/topi/cuda/nms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def get_valid_counts_scan(data, partial_in, partial):
185185
ib.scope_attr(bx, "thread_extent", nthread_bx)
186186
var = tvm.make.node("FloatImm", dtype="float32", value=2)
187187
new_range = num_anchors // elem_per_thread + 1
188-
iteration = log(cast(new_range, "float32")) // math.log(2)
188+
iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
189189
# Scan: Kogge-Stone adder
190190
with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
191191
with ib.for_range(0, iteration) as k:

0 commit comments

Comments
 (0)