Skip to content

Commit 39ff2ca

Browse files
committed
fallback ceildiv changes
1 parent 6dc7956 commit 39ff2ca

File tree

1 file changed

+7
-38
lines changed

1 file changed

+7
-38
lines changed

src/transform/lower_intrin.cc

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
#include <tvm/tir/op.h>
2929
#include <tvm/tir/transform.h>
3030

31-
#include <functional>
3231
#include <limits>
3332
#include <unordered_set>
3433

@@ -37,7 +36,6 @@
3736

3837
namespace tvm {
3938
namespace tl {
40-
4139
using namespace tir;
4240

4341
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
@@ -110,38 +108,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
110108
const DataType &dtype = op->dtype;
111109
ICHECK(dtype.is_int() || dtype.is_uint());
112110

113-
auto is_ceildiv_numerator = [&]() {
114-
std::vector<std::pair<PrimExpr, int>> terms;
115-
std::function<void(const PrimExpr &, int)> collect =
116-
[&](const PrimExpr &expr, int sign) {
117-
if (const auto *add = expr.as<AddNode>()) {
118-
collect(add->a, sign);
119-
collect(add->b, sign);
120-
} else if (const auto *sub = expr.as<SubNode>()) {
121-
collect(sub->a, sign);
122-
collect(sub->b, -sign);
123-
} else {
124-
terms.emplace_back(expr, sign);
125-
}
126-
};
127-
collect(op->a, 1);
128-
int den_coeff = 0;
129-
int64_t const_term = 0;
130-
for (const auto &term : terms) {
131-
const PrimExpr &expr = term.first;
132-
int sign = term.second;
133-
if (const auto *imm = expr.as<IntImmNode>()) {
134-
const_term += static_cast<int64_t>(sign) * imm->value;
135-
} else if (analyzer_->CanProveEqual(expr, op->b)) {
136-
den_coeff += sign;
137-
}
138-
}
139-
return den_coeff == 1 && const_term == -1;
140-
}();
141-
142-
if (support_bitwise_op_ && !is_ceildiv_numerator &&
143-
is_const_power_of_two_integer(op->b, &shift) &&
144-
analyzer_->CanProveGreaterEqual(op->a, 0)) {
111+
// lower (a + 31) // 32 to (a + 31) >> 5
112+
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
145113
// lower to right shift if possible.
146114
return op->a >> make_const(dtype, shift);
147115
}
@@ -192,8 +160,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
192160
// == truncdiv(a + b*c, b) - c
193161
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
194162
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
195-
PrimExpr offset_numerator =
196-
analyzer_->Simplify(op->a + op->b * ceildiv);
163+
// Skip analyzer simplification so we preserve straightforward div
164+
// expressions.
165+
PrimExpr offset_numerator = op->a + op->b * ceildiv;
197166
return truncdiv(offset_numerator, op->b) - ceildiv;
198167
}
199168

@@ -429,9 +398,9 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string &target) {
429398
}
430399

431400
namespace transform {
432-
using namespace tir::transform;
433401

434-
tvm::transform::Pass LowerIntrin() {
402+
tir::transform::Pass LowerIntrin() {
403+
using namespace tir::transform;
435404
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
436405
auto *n = f.CopyOnWrite();
437406
auto target = f->GetAttr<Target>(tvm::attr::kTarget);

0 commit comments

Comments
 (0)